# Source: LitBench-UI/src/tasks/gen_related_work.py (commit 908351f, "Upload 22 files", by Andreas99)
"""
Generate the related work section for a given paper.

Input:
- A string describing the paper for which the related work section should be
  generated, in the following format:
      Title of Paper: <title of the paper>
      Abstract of Paper: <abstract of the paper>

Output:
- A string containing the generated related work section, followed by a
  "### References" list of the cited papers.
"""
import torch
import json
import networkx as nx
import numpy as np
from tqdm import tqdm
from peft import PeftModel
from transformers import (AutoModel, AutoTokenizer, AutoModelForCausalLM, pipeline)
from tqdm import tqdm
import re
import pandas as pd
import os
from sklearn.metrics.pairwise import cosine_similarity
from utils.utils import read_yaml_file
import datetime
class LitFM():
    """Retrieval-augmented generator for the related work section of a paper.

    Combines:
      * a citation graph (GEXF) whose nodes carry ``title``/``abstract`` attributes,
      * a dense retriever over those papers (embeddings cached under ``datasets/``),
      * a causal LM with a PEFT adapter named ``"instruction"`` that picks the
        papers to cite and writes one citation sentence per paper, and
      * an instruction pipeline used zero-shot to propose topics, write one
        paragraph per topic and merge the paragraphs.
    """

    def __init__(self, graph_path, adapter_path):
        """Load the models, the retrieval graph and the prompt template.

        Args:
            graph_path: path to the ``.gexf`` citation graph used for retrieval.
            adapter_path: PEFT adapter loaded on top of
                ``config['inference']['base_model']``.

        Reads ``configs/config.yaml`` and ``configs/alpaca.json`` relative to the
        current working directory.
        """
        # File stem of the graph; it keys the on-disk embedding cache.
        # basename/splitext keeps e.g. "./datasets/g.gexf" from collapsing to "".
        self.graph_name = os.path.splitext(os.path.basename(graph_path))[0]
        self.batch_size = 32
        self.neigh_num = 4
        config = read_yaml_file('configs/config.yaml')
        self.pretrained_model = config['retriever']['embedder']

        # Generation model: base causal LM + PEFT adapter named "instruction".
        model_path = config['inference']["base_model"]
        self.generation_tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.generation_tokenizer.model_max_length = 2048
        if self.generation_tokenizer.pad_token is None:
            self.generation_tokenizer.pad_token = self.generation_tokenizer.eos_token
        self.generation_model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="auto")
        self.generation_model = PeftModel.from_pretrained(
            self.generation_model, adapter_path, adapter_name="instruction", torch_dtype=torch.float16)

        # With a chat template, 'instruction' prompts are chat messages sent
        # through a text-generation pipeline instead of generate_text().
        # NOTE(review): this pipeline is built from model_path, i.e. a second copy
        # of the base model WITHOUT the PEFT adapter -- confirm that is intended.
        self.model_pipeline = None
        if self.generation_tokenizer.chat_template is not None:
            self.model_pipeline = pipeline(
                "text-generation",
                model=model_path,
                model_kwargs={"torch_dtype": torch.bfloat16},
                device_map="auto",
            )
        # Instruction model used for the zero-shot prompts.
        self.instruction_pipe = pipeline(
            "text-generation",
            model=config["inference"]["gen_related_work_instruct_model"],
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

        # Retrieval graph: assign dense ids 0..n-1 in node order and keep
        # [title, abstract] for every dense id.
        graph = nx.read_gexf(graph_path, node_type=None, relabel=False, version='1.2draft')
        self.whole_graph_raw_id_2_id_dict = {}
        self.whole_graph_id_2_raw_id_dict = {}
        self.whole_graph_id_2_title_abs = {}
        for dense_id, (raw_id, attrs) in enumerate(graph.nodes(data=True)):
            self.whole_graph_raw_id_2_id_dict[raw_id] = dense_id
            self.whole_graph_id_2_raw_id_dict[dense_id] = raw_id
            # A node without an abstract is embedded from its title alone.
            self.whole_graph_id_2_title_abs[dense_id] = [attrs['title'], attrs.get('abstract', '')]

        # Lazily populated by the retrieval helpers (loaded once per instance).
        self._query_encoder = None
        self._corpus_embeddings = None

        # Prompt template for the non-chat 'instruction' path.
        with open('configs/alpaca.json', encoding='utf-8') as fp:
            self.template = json.load(fp)
        self.human_instruction = ['### Input:', '### Response:']

    # ------------------------------------------------------------------ prompts

    def _wrap_instruction_prompt(self, instruction, prompt_input):
        """Package an instruction/input pair for the 'instruction' model path.

        Returns a system/user chat message list when a chat pipeline is in use,
        otherwise the ``prompt_input`` template from configs/alpaca.json filled
        with both parts.
        """
        if self.model_pipeline is not None:
            return [
                {"role": "system", "content": instruction},
                {"role": "user", "content": prompt_input},
            ]
        return self.template["prompt_input"].format(instruction=instruction, input=prompt_input)

    def _generate_retrieval_prompt(self, data_point: dict):
        """Build the prompt asking the model to pick one paper from ``nei_titles``.

        Candidates are listed as "0. <title>", "1. <title>", ...
        """
        instruction = "Please select the paper that is more likely to be cited by the paper from the list of candidate papers. Your answer MUST be **only the exact title** of the selected paper without generating ANY other text or section. Your answer MUST belong to the list of candidate papers.\n"
        candidate_lines = "".join(
            str(i) + '. ' + title + "\n" for i, title in enumerate(data_point['nei_titles'])
        )
        prompt_input = data_point['usr_prompt'] + "\n" + "candidate papers: " + "\n" + candidate_lines
        return self._wrap_instruction_prompt(instruction, prompt_input)

    def _generate_sentence_prompt(self, data_point):
        """Build the prompt asking how the user's paper cites paper B.

        A ``None`` title or abstract is rendered as 'Unknown'.
        """
        instruction = "Please generate the citation sentence of how the Paper cites paper B in its related work section."
        t_title = data_point['t_title'] if data_point['t_title'] is not None else 'Unknown'
        t_abs = data_point['t_abs'] if data_point['t_abs'] is not None else 'Unknown'
        prompt_input = (
            data_point['usr_prompt'] + "\n"
            + "Title of Paper B: " + t_title + "\n"
            + "Abstract of Paper B: " + t_abs + "\n"
        )
        return self._wrap_instruction_prompt(instruction, prompt_input)

    def _generate_topic_prompt(self, data_point):
        """Build the zero-shot chat prompt asking for three related-work topics."""
        prompt_input = (
            "Here are the information of the paper: \n"
            + data_point['usr_prompt'] + '\n'
            + "Directlty give me the topics you select.\n"
        )
        return [
            {"role": "system", "content": "I need to write the related work section for this paper. Could you suggest three most relevant topics to discuss in the related work section? Your answer should be strictly one topic after the other line by line with nothing else being generated and no further explanation/information.\n"},
            {"role": "user", "content": prompt_input},
        ]

    def _generate_paragraph_prompt(self, data_point):
        """Build the zero-shot chat prompt for one topic paragraph.

        Each cited paper is numbered from ``paper_citation_indicator`` onwards so
        that numbering is continuous across paragraphs.
        """
        prompt_input = (
            data_point['usr_prompt'] + "\n"
            + "Topic of this paragraph: " + data_point['topic'] + "\n"
            + "Papers that should be cited in paragraph: \n"
        )
        first_index = data_point['paper_citation_indicator']
        for offset, (paper, sentence) in enumerate(zip(data_point['nei_title'], data_point['nei_sentence'])):
            prompt_input += "[" + str(first_index + offset) + "]. " + paper[0] + '.' + " Citation sentence of this paper in the paragraph: " + sentence + '\n'
        prompt_input += "All the above cited papers should be included and each cited paper should be indicated with its index number. Note that you should not include the title of any paper\n"
        return [
            {"role": "system", "content": "Please write a paragraph that review the research relationships between this paper and other cited papers.\n"},
            {"role": "user", "content": prompt_input},
        ]

    def _generate_summary_prompt(self, data_point):
        """Build the zero-shot chat prompt that merges the topic paragraphs."""
        prompt_input = data_point['usr_prompt'] + "\n" + "Paragraphs that should be combined: " + "\n"
        for number, para in enumerate(data_point['paragraphs'], start=1):
            prompt_input += " Paragraph " + str(number) + ": " + para + '\n'
        return [
            {"role": "system", "content": "Please combine the following paragraphs in a cohenrent way that also keeps the citations and make the flow between paragraphs more smoothly\nAdd a sentence at the beginning of each paragraph to clarify its connection to the previous ones. Do not include any other surrounding text and not add a references list at all\n"},
            {"role": "user", "content": prompt_input},
        ]

    # --------------------------------------------------------------- generation

    @staticmethod
    def generate_text(prompt, tokenizer, model, temperature, top_p, repetition_penalty, max_new_tokens):
        """Sample a completion for a plain-text ``prompt``.

        Returns the decoded output sequence (special tokens skipped). It still
        contains the prompt text, which is why callers split on the response marker.
        """
        # Place inputs where the model lives instead of assuming "cuda".
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            output = model.generate(
                **inputs,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                max_new_tokens=max_new_tokens,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
            )
        return tokenizer.decode(output[0], skip_special_tokens=True)

    def get_llm_response(self, prompt, model_type):
        """Run ``prompt`` through the model selected by ``model_type``.

        Args:
            prompt: chat messages (list of dicts) or, for 'instruction' without a
                chat pipeline, a plain string.
            model_type: 'zeroshot' / 'zeroshot_short' use ``self.instruction_pipe``
                (8096 / 256 new tokens); 'instruction' uses the chat pipeline when
                one exists, otherwise the PEFT model via ``generate_text``.

        Returns:
            The last ``generated_text`` element of a pipeline call (the assistant
            message dict for chat input) or the string from ``generate_text``.
            Use ``_response_text`` to obtain plain text from either.

        Raises:
            ValueError: if ``model_type`` is not one of the values above.
        """
        self.generation_model.set_adapter('instruction')
        sampling = dict(temperature=0.9, top_p=0.95, repetition_penalty=1.15)
        if model_type in ('zeroshot', 'zeroshot_short'):
            max_new_tokens = 8096 if model_type == 'zeroshot' else 256
            return self.instruction_pipe(prompt, max_new_tokens=max_new_tokens, **sampling)[0]['generated_text'][-1]
        if model_type == 'instruction':
            if self.model_pipeline is not None:
                return self.model_pipeline(prompt, **sampling)[0]['generated_text'][-1]
            return self.generate_text(
                prompt,
                self.generation_tokenizer,
                self.generation_model,
                max_new_tokens=256,
                **sampling,
            )
        raise ValueError(f"Unknown model_type: {model_type!r}")

    @staticmethod
    def _response_text(ans):
        """Return plain text from a ``get_llm_response`` result.

        Chat pipelines return a message dict (text under 'content'); the
        ``generate_text`` path already returns a string.
        """
        if isinstance(ans, dict):
            return ans.get('content', '')
        return ans

    def single_paper_sentence_test(self, usr_prompt, t_title, t_abs):
        """Generate the sentence with which the user's paper cites paper B."""
        datapoint = {'usr_prompt': usr_prompt, 't_title': t_title, 't_abs': t_abs}
        prompt = self._generate_sentence_prompt(datapoint)
        ans = self._response_text(self.get_llm_response(prompt, 'instruction'))
        # Keep only the text after the template's response marker (if present).
        return ans.strip().split(self.human_instruction[1])[-1]

    def single_paper_retrieval_test(self, usr_prompt, candidates):
        """Ask the model which of ``candidates`` (titles) is most likely cited."""
        datapoint = {'usr_prompt': usr_prompt, 'nei_titles': list(candidates), 't_title': ''}
        prompt = self._generate_retrieval_prompt(datapoint)
        ans = self._response_text(self.get_llm_response(prompt, 'instruction'))
        return ans.strip().split(self.human_instruction[1])[-1]

    def single_paper_topic_test(self, usr_prompt):
        """Ask the zero-shot model for related-work topics, one per line."""
        prompt = self._generate_topic_prompt({'usr_prompt': usr_prompt})
        res = self._response_text(self.get_llm_response(prompt, 'zeroshot_short'))
        return res.replace('\n\n', '\n')

    # ---------------------------------------------------------------- retrieval

    @staticmethod
    def _embed_corpus(id_2_title_abs, batch_size=200):
        """Embed every paper (title + abstract) and return a float32 CPU tensor.

        Row i corresponds to the i-th key of ``id_2_title_abs``.
        """
        # NOTE(review): the corpus uses a hard-coded BAAI/bge-large-en-v1.5 while
        # queries use config['retriever']['embedder']; both must be the same model
        # for the similarities to be meaningful -- confirm the config.
        tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-large-en-v1.5")
        model = AutoModel.from_pretrained("BAAI/bge-large-en-v1.5").to(device='cuda', dtype=torch.float16)
        model.eval()
        paper_list = list(id_2_title_abs.keys())
        batches = []
        with tqdm(total=len(paper_list)) as pbar:
            for start in range(0, len(paper_list), batch_size):
                paper_batch = paper_list[start:start + batch_size]
                # Title and abstract are joined without a separator; kept as-is so
                # newly computed embeddings match previously cached ones.
                texts = [id_2_title_abs[p][0] + id_2_title_abs[p][1] for p in paper_batch]
                inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
                with torch.no_grad():
                    outputs = model(**inputs.to('cuda'))
                # First-token embedding of the last layer; float32 so it can be
                # written to parquet.
                batches.append(outputs.last_hidden_state[:, 0, :].float().cpu())
                pbar.update(len(paper_batch))
        return torch.cat(batches, 0)

    def _get_corpus_embeddings(self, id_2_title_abs):
        """Return corpus embeddings, using the in-memory and parquet caches.

        A parquet cache whose row count does not match ``id_2_title_abs`` is
        treated as stale and recomputed.
        """
        cached = getattr(self, '_corpus_embeddings', None)
        if cached is not None and len(cached) == len(id_2_title_abs):
            return cached
        cache_path = f'datasets/{self.graph_name}_embeddings.parquet'
        embeddings = None
        if os.path.exists(cache_path):
            embeddings = torch.tensor(pd.read_parquet(cache_path).to_numpy(dtype=np.float32))
            if len(embeddings) != len(id_2_title_abs):
                embeddings = None
        if embeddings is None:
            embeddings = self._embed_corpus(id_2_title_abs)
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            # String column names: some pandas versions reject other names in to_parquet.
            frame = pd.DataFrame(embeddings.numpy(), columns=[str(c) for c in range(embeddings.shape[1])])
            frame.to_parquet(cache_path)
        self._corpus_embeddings = embeddings
        return embeddings

    def _get_query_encoder(self):
        """Return (tokenizer, model) for embedding queries, loading them once."""
        if getattr(self, '_query_encoder', None) is None:
            tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model)
            model = AutoModel.from_pretrained(self.pretrained_model).cuda()
            model.eval()
            self._query_encoder = (tokenizer, model)
        return self._query_encoder

    def retrieval_for_one_query(self, id_2_title_abs, prompt):
        """Return the 10 papers whose embeddings are most similar to ``prompt``.

        Args:
            id_2_title_abs: dict mapping dense paper id -> [title, abstract].
            prompt: query text (here, one topic produced by the LLM).

        Returns:
            (titles, raw_ids): two parallel lists ordered by decreasing cosine
            similarity; raw ids come from ``self.whole_graph_id_2_raw_id_dict``.
        """
        paper_list = list(id_2_title_abs.keys())
        corpus_embs = self._get_corpus_embeddings(id_2_title_abs)
        tokenizer, query_model = self._get_query_encoder()
        encoded_input = tokenizer([prompt], padding=True, truncation=True, max_length=512, return_tensors='pt')
        with torch.no_grad():
            hidden = query_model(**encoded_input.to('cuda'), output_hidden_states=True).hidden_states[-1]
        query_emb = hidden[:, 0, :].float().cpu()
        scores = cosine_similarity(query_emb.numpy(), corpus_embs.numpy())[0]
        _, order = torch.sort(torch.tensor(scores), descending=True)
        # Map embedding rows back to dict keys rather than assuming keys == rows.
        top_ids = [paper_list[int(k)] for k in order[:10]]
        titles = [id_2_title_abs[p][0] for p in top_ids]
        raw_ids = [self.whole_graph_id_2_raw_id_dict[p] for p in top_ids]
        return titles, raw_ids

    # ---------------------------------------------------------------- pipeline

    @staticmethod
    def _split_topics(topic_text, topic_num=3):
        """Return at most ``topic_num`` non-empty, stripped lines of ``topic_text``."""
        lines = (line.strip() for line in topic_text.strip().split('\n'))
        return [line for line in lines if line][:topic_num]

    @staticmethod
    def _match_selected_paper(selected_paper, candidate_titles):
        """Return the index in ``candidate_titles`` of the model's choice, or -1.

        The last whitespace-separated token that parses as an in-range int wins
        (candidates are shown as "0. title", "1. title", ...). Otherwise a title
        matches when one of the lower-cased, space-free strings contains the other.
        An empty answer matches nothing.
        """
        index = -1
        for word in selected_paper.strip().split(' '):
            try:
                index = int(word)
            except ValueError:
                pass
        if 0 <= index < len(candidate_titles):
            return index
        normalized = selected_paper.lower().replace(' ', '')
        if not normalized:
            return -1
        for i, title in enumerate(candidate_titles):
            normalized_title = title.lower().replace(' ', '')
            if normalized_title in normalized or normalized in normalized_title:
                return i
        return -1

    def _select_citations_for_topic(self, usr_prompt, topic, num_papers=5):
        """Retrieve candidates for ``topic`` and let the model pick up to ``num_papers``.

        Returns a list of [title, raw_id] pairs in selection order.
        """
        titles, raw_ids = self.retrieval_for_one_query(self.whole_graph_id_2_title_abs, topic)
        titles, raw_ids = list(titles), list(raw_ids)
        selected = []
        for _ in range(num_papers):
            if not titles:
                break
            answer = self.single_paper_retrieval_test(usr_prompt, titles).replace(' \n', '').replace('\n', '')
            index = self._match_selected_paper(answer, titles)
            if index == -1:
                continue
            # Pop from both lists so titles and raw ids stay aligned and the
            # numbering shown in the next prompt matches the remaining list.
            selected.append([titles.pop(index), raw_ids.pop(index)])
        return selected

    @staticmethod
    def _format_reference(number, title, raw_id):
        """Format one reference line: "[n] title, arXiv <id>, <Month> <Year>".

        The date is decoded from an arXiv-style id with a YYMM prefix (e.g.
        "2101.12345" or "hep-th/9901001v2"). Ids that do not decode to a valid
        year/month fall back to "[n] title, <raw_id>" instead of raising.
        """
        arxiv_id = re.sub(r'v\d+$', '', str(raw_id))   # drop a version suffix
        arxiv_id = re.sub(r'[^0-9.]', '', arxiv_id)     # drop an archive prefix
        yy, mm = arxiv_id[:2], arxiv_id[2:4]
        if len(yy) == 2 and yy.isdigit() and len(mm) == 2 and mm.isdigit() and 1 <= int(mm) <= 12:
            year = ('19' if int(yy) > 70 else '20') + yy
            month = datetime.date(1900, int(mm), 1).strftime('%B')
            return f"[{number}] {title}, arXiv {arxiv_id}, {month} {year}"
        return f"[{number}] {title}, {raw_id}"

    def single_paper_related_work_generation(self, usr_prompt):
        """Generate a related work section with a numbered reference list.

        Steps: ask for up to three topics, retrieve and select up to five papers
        per topic, write a citation sentence per paper, write one paragraph per
        topic, merge the paragraphs, and append "### References".
        """
        topic_num = 3
        split_topics = self._split_topics(self.single_paper_topic_test(usr_prompt), topic_num)

        # Keep each topic paired with its papers so that dropping a topic with
        # no selected citations cannot shift topics against paper lists.
        topics_with_papers = []
        for topic in split_topics:
            papers = self._select_citations_for_topic(usr_prompt, topic)
            if papers:
                topics_with_papers.append((topic, papers))

        # One citation sentence per selected paper, with \cite{...} removed.
        nei_sentence = []
        for _, papers in topics_with_papers:
            sentences = []
            for title, _raw_id in papers:
                sentence = self.single_paper_sentence_test(usr_prompt, title, "")
                sentences.append(re.sub(r'\\cite\{[^{}]+\}', "", sentence))
            nei_sentence.append(sentences)

        # One paragraph per topic; citation numbers continue across paragraphs.
        paragraphs = []
        references = []
        paper_citation_indicator = 1
        for (topic, papers), sentences in zip(topics_with_papers, nei_sentence):
            datapoint = {'usr_prompt': usr_prompt,
                         'nei_title': papers,
                         'nei_sentence': sentences,
                         'topic': topic,
                         'paper_citation_indicator': paper_citation_indicator}
            prompt = self._generate_paragraph_prompt(datapoint)
            paragraphs.append(self._response_text(self.get_llm_response(prompt, 'zeroshot')))
            for offset, (title, raw_id) in enumerate(papers):
                references.append(self._format_reference(paper_citation_indicator + offset, title, raw_id))
            paper_citation_indicator += len(papers)

        prompt = self._generate_summary_prompt({'usr_prompt': usr_prompt, 'paragraphs': paragraphs})
        summary = self._response_text(self.get_llm_response(prompt, 'zeroshot'))
        return summary + "\n\n### References\n" + "\n".join(references)
def gen_related_work(message, graph_path, adapter_path):
    """Generate a related work section (with references) for one paper.

    Args:
        message: title/abstract prompt describing the paper.
        graph_path: citation graph used for retrieval.
        adapter_path: PEFT adapter for the generation model.

    A new LitFM is constructed on every call, so the models and the retrieval
    graph are loaded again each time.
    """
    generator = LitFM(graph_path=graph_path, adapter_path=adapter_path)
    related_work = generator.single_paper_related_work_generation(usr_prompt=message)
    return related_work