"""Generate the related-work section for a given paper.

Input
    A string describing the paper whose related-work section should be
    generated, in the following format::

        Title of Paper: <title of the paper>
        Abstract of Paper: <abstract of the paper>

Output
    A string containing the generated related-work section, followed by a
    ``### References`` list of the cited papers.
"""

import datetime
import json
import os
import re

import networkx as nx
import numpy as np
import pandas as pd
import torch
from peft import PeftModel
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
                          pipeline)

from utils.utils import read_yaml_file


class LitFM:
    """Related-work generator backed by a citation graph and LLMs.

    Uses a base causal LM with an "instruction" PEFT adapter to pick citations
    and write citation sentences, a separate instruction pipeline to propose
    topics and write paragraphs/summary, and dense retrieval over the
    title/abstract attributes of the nodes of a GEXF graph.
    """

    def __init__(self, graph_path, adapter_path):
        """Load the models, the retrieval graph and the prompt template.

        Args:
            graph_path: Path to a GEXF graph whose nodes carry ``title`` and
                ``abstract`` attributes.
            adapter_path: Path to the PEFT adapter loaded (as adapter
                ``"instruction"``) on top of ``config['inference']['base_model']``.
        """
        # File stem up to the first '.', used to name the on-disk embedding
        # cache. basename() first so a '.' in a directory name ('../data')
        # cannot truncate it to an empty string.
        self.graph_name = os.path.basename(graph_path).split('.')[0]
        self.batch_size = 32
        self.neigh_num = 4
        config = read_yaml_file('configs/config.yaml')
        # Embedding model used to encode retrieval queries.
        self.pretrained_model = config['retriever']['embedder']

        # Generation model: base causal LM plus the instruction adapter.
        model_path = config['inference']["base_model"]
        self.generation_tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.generation_tokenizer.model_max_length = 2048
        if self.generation_tokenizer.pad_token is None:
            self.generation_tokenizer.pad_token = self.generation_tokenizer.eos_token
        self.generation_model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.bfloat16, device_map="auto")
        self.generation_model = PeftModel.from_pretrained(
            self.generation_model, adapter_path,
            adapter_name="instruction", torch_dtype=torch.float16)

        # Chat-style pipeline, created only when the tokenizer has a chat
        # template; the retrieval/sentence prompt builders then emit message
        # lists instead of Alpaca strings.
        # NOTE(review): this loads `model_path` a second time and does not
        # include the PEFT adapter above — confirm that is intended.
        self.model_pipeline = None
        if self.generation_tokenizer.chat_template is not None:
            self.model_pipeline = pipeline(
                "text-generation",
                model=model_path,
                model_kwargs={"torch_dtype": torch.bfloat16},
                device_map="auto",
            )

        # Instruction model used for topics, paragraphs and the final summary.
        self.instruction_pipe = pipeline(
            "text-generation",
            model=config["inference"]["gen_related_work_instruct_model"],
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

        # Retrieval graph: dense ids 0..N-1 <-> raw node ids, and
        # [title, abstract] per dense id (in node iteration order).
        whole_graph_data_raw = nx.read_gexf(
            graph_path, node_type=None, relabel=False, version='1.2draft')
        (self.whole_graph_raw_id_2_id_dict,
         self.whole_graph_id_2_raw_id_dict) = self._translate_graph(whole_graph_data_raw)
        self.whole_graph_id_2_title_abs = {
            self.whole_graph_raw_id_2_id_dict[paper_id]: [attrs['title'], attrs['abstract']]
            for paper_id, attrs in whole_graph_data_raw.nodes(data=True)
        }

        # Lazily-filled caches for retrieval_for_one_query.
        self._candidate_embs = None
        self._query_embedder = None

        # Alpaca-style template used when no chat template is available.
        with open('configs/alpaca.json', encoding='utf-8') as fp:
            self.template = json.load(fp)
        self.human_instruction = ['### Input:', '### Response:']

    @staticmethod
    def _translate_graph(graph):
        """Return ``(raw_id -> dense_id, dense_id -> raw_id)`` dicts, numbering nodes in iteration order."""
        raw_id_2_id = {node: idx for idx, node in enumerate(graph.nodes())}
        id_2_raw_id = {idx: node for node, idx in raw_id_2_id.items()}
        return raw_id_2_id, id_2_raw_id

    # ------------------------------------------------------------------ prompts

    def _format_instruction_prompt(self, instruction, prompt_input):
        """Wrap an instruction/input pair as chat messages or an Alpaca string.

        Chat messages are used when ``self.model_pipeline`` exists, otherwise
        ``self.template['prompt_input']`` is formatted.
        """
        if self.model_pipeline is not None:
            return [
                {"role": "system", "content": instruction},
                {"role": "user", "content": prompt_input},
            ]
        return self.template["prompt_input"].format(instruction=instruction, input=prompt_input)

    def _generate_retrieval_prompt(self, data_point: dict):
        """Build the prompt asking the model to choose one candidate paper to cite.

        ``data_point`` needs ``usr_prompt`` and ``nei_titles`` (candidate
        titles, listed with 0-based numbers in the prompt).
        """
        instruction = "Please select the paper that is more likely to be cited by the paper from the list of candidate papers. Your answer MUST be **only the exact title** of the selected paper without generating ANY other text or section. Your answer MUST belong to the list of candidate papers.\n"
        parts = [data_point['usr_prompt'] + "\n", "candidate papers: " + "\n"]
        parts.extend(str(i) + '. ' + title + "\n"
                     for i, title in enumerate(data_point['nei_titles']))
        return self._format_instruction_prompt(instruction, "".join(parts))

    def _generate_sentence_prompt(self, data_point):
        """Build the prompt asking for the sentence with which the paper cites paper B.

        ``data_point`` needs ``usr_prompt``, ``t_title`` and ``t_abs``; a
        ``None`` title/abstract is rendered as ``'Unknown'``.
        """
        instruction = "Please generate the citation sentence of how the Paper cites paper B in its related work section."
        t_title = data_point['t_title'] if data_point['t_title'] is not None else 'Unknown'
        t_abs = data_point['t_abs'] if data_point['t_abs'] is not None else 'Unknown'
        prompt_input = (data_point['usr_prompt'] + "\n"
                        + "Title of Paper B: " + t_title + "\n"
                        + "Abstract of Paper B: " + t_abs + "\n")
        return self._format_instruction_prompt(instruction, prompt_input)

    def _generate_topic_prompt(self, data_point):
        """Build chat messages asking for three related-work topics, one per line."""
        prompt_input = ("Here are the information of the paper: \n"
                        + data_point['usr_prompt'] + '\n'
                        + "Directlty give me the topics you select.\n")
        return [
            {"role": "system", "content": "I need to write the related work section for this paper. Could you suggest three most relevant topics to discuss in the related work section? Your answer should be strictly one topic after the other line by line with nothing else being generated and no further explanation/information.\n"},
            {"role": "user", "content": prompt_input},
        ]

    def _generate_paragraph_prompt(self, data_point):
        """Build chat messages asking for one paragraph on a topic.

        ``data_point`` needs ``usr_prompt``, ``topic``, ``nei_title`` (list of
        ``[title, raw_id]`` pairs), ``nei_sentence`` (one citation sentence per
        paper) and ``paper_citation_indicator`` (number of the first paper).
        """
        parts = [
            data_point['usr_prompt'] + "\n",
            "Topic of this paragraph: " + data_point['topic'] + "\n",
            "Papers that should be cited in paragraph: \n",
        ]
        numbered = enumerate(zip(data_point['nei_title'], data_point['nei_sentence']),
                             start=data_point['paper_citation_indicator'])
        for number, (paper, sentence) in numbered:
            parts.append(f"[{number}]. {paper[0]}. Citation sentence of this paper in the paragraph: {sentence}\n")
        parts.append("All the above cited papers should be included and each cited paper should be indicated with its index number. Note that you should not include the title of any paper\n")
        return [
            {"role": "system", "content": "Please write a paragraph that review the research relationships between this paper and other cited papers.\n"},
            {"role": "user", "content": "".join(parts)},
        ]

    def _generate_summary_prompt(self, data_point):
        """Build chat messages asking to merge ``data_point['paragraphs']`` into one section."""
        parts = [data_point['usr_prompt'] + "\n", "Paragraphs that should be combined: " + "\n"]
        parts.extend(f" Paragraph {i}: {para}\n"
                     for i, para in enumerate(data_point['paragraphs'], start=1))
        return [
            {"role": "system", "content": "Please combine the following paragraphs in a cohenrent way that also keeps the citations and make the flow between paragraphs more smoothly\nAdd a sentence at the beginning of each paragraph to clarify its connection to the previous ones. Do not include any other surrounding text and not add a references list at all\n"},
            {"role": "user", "content": "".join(parts)},
        ]

    # --------------------------------------------------------------- generation

    @staticmethod
    def generate_text(prompt, tokenizer, model, temperature, top_p, repetition_penalty, max_new_tokens):
        """Sample a completion for a plain-string prompt with ``model.generate``.

        Returns the decoded first sequence, which includes the prompt text
        (callers split on ``'### Response:'`` to extract the answer).
        """
        # Send inputs to the model's own device instead of assuming "cuda".
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            output = model.generate(
                **inputs,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                max_new_tokens=max_new_tokens,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
            )
        return tokenizer.decode(output[0], skip_special_tokens=True)

    def get_llm_response(self, prompt, model_type):
        """Run one LLM call.

        Args:
            prompt: Chat messages, or an Alpaca string for ``'instruction'``
                when there is no chat pipeline.
            model_type: ``'zeroshot'`` / ``'zeroshot_short'`` (instruction
                pipeline, long / 256-token output) or ``'instruction'``
                (adapter model).

        Returns:
            The last element of the pipeline's ``generated_text`` (a message
            dict with a ``'content'`` key for chat input) for pipeline calls,
            or the decoded string from :meth:`generate_text`.

        Raises:
            ValueError: if ``model_type`` is not one of the above.
        """
        if model_type == 'zeroshot':
            # NOTE(review): 8096 looks like it may have been meant as 8192.
            return self.instruction_pipe(
                prompt, max_new_tokens=8096, temperature=0.9, top_p=0.95, repetition_penalty=1.15,
            )[0]['generated_text'][-1]
        if model_type == 'zeroshot_short':
            return self.instruction_pipe(
                prompt, max_new_tokens=256, temperature=0.9, top_p=0.95, repetition_penalty=1.15,
            )[0]['generated_text'][-1]
        if model_type == 'instruction':
            self.generation_model.set_adapter('instruction')
            if self.model_pipeline is not None:
                # Same output budget as the generate_text path below.
                return self.model_pipeline(
                    prompt, max_new_tokens=256, temperature=0.9, top_p=0.95, repetition_penalty=1.15,
                )[0]['generated_text'][-1]
            return self.generate_text(
                prompt,
                self.generation_tokenizer,
                self.generation_model,
                temperature=0.9,
                top_p=0.95,
                repetition_penalty=1.15,
                max_new_tokens=256,
            )
        raise ValueError(f"Unknown model_type: {model_type!r}")

    @staticmethod
    def _response_text(response):
        """Return the text of an LLM response (chat message dicts carry it under ``'content'``)."""
        if isinstance(response, dict):
            return response['content']
        return response

    def single_paper_sentence_test(self, usr_prompt, t_title, t_abs):
        """Generate the sentence with which the user's paper would cite paper ``t_title``."""
        datapoint = {'usr_prompt': usr_prompt, 't_title': t_title, 't_abs': t_abs}
        prompt = self._generate_sentence_prompt(datapoint)
        # The chat pipeline returns a message dict; normalize to text first.
        ans = self._response_text(self.get_llm_response(prompt, 'instruction'))
        return ans.strip().split(self.human_instruction[1])[-1]

    def single_paper_retrieval_test(self, usr_prompt, candidates):
        """Ask the model which of ``candidates`` (titles) the user's paper should cite; returns raw answer text."""
        datapoint = {'usr_prompt': usr_prompt, 'nei_titles': list(candidates), 't_title': ''}
        prompt = self._generate_retrieval_prompt(datapoint)
        ans = self._response_text(self.get_llm_response(prompt, 'instruction'))
        return ans.strip().split(self.human_instruction[1])[-1]

    def single_paper_topic_test(self, usr_prompt):
        """Ask the instruction model for related-work topics; returns newline-separated text."""
        datapoint = {'usr_prompt': usr_prompt}
        prompt = self._generate_topic_prompt(datapoint)
        ans = self.get_llm_response(prompt, 'zeroshot_short')
        return ans['content'].replace('\n\n', '\n')

    # ---------------------------------------------------------------- retrieval

    def _candidate_embeddings(self, id_2_title_abs):
        """Return the document-embedding matrix, one row per paper id in ``id_2_title_abs`` order.

        Loaded from ``datasets/<graph_name>_embeddings.parquet`` when present,
        otherwise computed with BAAI/bge-large-en-v1.5 ([CLS] vectors) and
        written there. The result is also cached on the instance.
        """
        if getattr(self, '_candidate_embs', None) is not None:
            return self._candidate_embs
        cache_path = f'datasets/{self.graph_name}_embeddings.parquet'
        if os.path.exists(cache_path):
            embs = torch.tensor(np.array(pd.read_parquet(cache_path)))
        else:
            # NOTE(review): documents use bge-large while queries use
            # config['retriever']['embedder']; they must share a space.
            tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-large-en-v1.5")
            model = AutoModel.from_pretrained("BAAI/bge-large-en-v1.5").to(device='cuda', dtype=torch.float16)
            model.eval()
            paper_list = list(id_2_title_abs.keys())
            batch_size = 200
            batches = []
            with tqdm(total=len(paper_list)) as pbar:
                for start in range(0, len(paper_list), batch_size):
                    paper_texts = [id_2_title_abs[pid][0] + id_2_title_abs[pid][1]
                                   for pid in paper_list[start:start + batch_size]]
                    inputs = tokenizer(paper_texts, return_tensors='pt', padding=True, truncation=True)
                    with torch.no_grad():
                        outputs = model(**inputs.to('cuda'))
                    batch_embs = outputs.last_hidden_state[:, 0, :].cpu().reshape(-1, 1024)
                    batches.append(batch_embs)
                    pbar.update(len(batch_embs))
            # float32: float16 columns are not writable to Parquet by older pyarrow.
            embs = torch.cat(batches, 0).float()
            os.makedirs(os.path.dirname(cache_path), exist_ok=True)
            pd.DataFrame(embs.numpy()).to_parquet(cache_path)
        self._candidate_embs = embs
        return embs

    def _get_query_embedder(self):
        """Return ``(tokenizer, model)`` for ``self.pretrained_model``, loading them once."""
        if getattr(self, '_query_embedder', None) is None:
            tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model)
            model = AutoModel.from_pretrained(self.pretrained_model).cuda()
            model.eval()
            self._query_embedder = (tokenizer, model)
        return self._query_embedder

    def retrieval_for_one_query(self, id_2_title_abs, prompt, top_k=10):
        """Return the ``top_k`` papers most similar to ``prompt``.

        Args:
            id_2_title_abs: dense id -> ``[title, abstract]``.
            prompt: Query text (here, a topic).
            top_k: Number of papers to return (default 10).

        Returns:
            ``(titles, raw_ids)`` lists ordered by decreasing cosine similarity.
        """
        all_query_embs = self._candidate_embeddings(id_2_title_abs)
        tokenizer, query_model = self._get_query_embedder()
        encoded_input = tokenizer([prompt], padding=True, truncation=True, max_length=512, return_tensors='pt')
        with torch.no_grad():
            output = query_model(**encoded_input.to('cuda'), output_hidden_states=True).hidden_states[-1]
        sentence_embedding = output[:, 0, :]
        scores = cosine_similarity(sentence_embedding.to("cpu"), all_query_embs.to("cpu"))[0]
        _, order = torch.sort(torch.tensor(scores), descending=True)
        top_ids = [int(k) for k in order[:top_k]]
        return ([id_2_title_abs[i][0] for i in top_ids],
                [self.whole_graph_id_2_raw_id_dict[i] for i in top_ids])

    # ------------------------------------------------------------- orchestration

    @staticmethod
    def _split_topics(text, topic_num=3):
        """Return up to ``topic_num`` non-blank, stripped lines of ``text``."""
        return [line.strip() for line in text.strip().split('\n') if line.strip()][:topic_num]

    @staticmethod
    def _match_candidate(answer, candidate_titles):
        """Return the index of the candidate ``answer`` refers to, or ``None``.

        A bare integer token is taken as a 0-based index (the last such token
        wins); otherwise a case- and space-insensitive substring match against
        the titles is tried. An empty answer matches nothing.
        """
        index = -1
        for word in answer.strip().split(' '):
            try:
                index = int(word)
            except ValueError:
                pass
        if 0 <= index < len(candidate_titles):
            return index
        normalized_answer = answer.lower().replace(' ', '')
        if not normalized_answer:
            return None
        for i, title in enumerate(candidate_titles):
            normalized_title = title.lower().replace(' ', '')
            if normalized_title in normalized_answer or normalized_answer in normalized_title:
                return i
        return None

    def _select_citations_for_topic(self, usr_prompt, topic, num_papers=5):
        """Retrieve candidates for ``topic`` and let the LLM pick up to ``num_papers``.

        Returns ``[title, raw_id]`` pairs in selection order.
        """
        titles, raw_ids = self.retrieval_for_one_query(self.whole_graph_id_2_title_abs, topic)
        titles, raw_ids = list(titles), list(raw_ids)
        selected = []
        for _ in range(num_papers):
            if not titles:
                break
            answer = self.single_paper_retrieval_test(usr_prompt, titles).replace(' \n', '').replace('\n', '')
            index = self._match_candidate(answer, titles)
            if index is None:
                continue
            # Pop from both lists together so each title stays paired with its
            # raw id and the remaining candidates keep their relevance order.
            selected.append([titles.pop(index), raw_ids.pop(index)])
        return selected

    @staticmethod
    def _format_reference(number, title, raw_id):
        """Format one reference line, deriving month/year from an arXiv-style id (YYMM...).

        Ids that do not start with a parsable YYMM are listed without a date.
        """
        digits = re.sub(r'[a-zA-Z/]+', '', raw_id)
        try:
            yy = int(digits[:2])
            month = datetime.date(1900, int(digits[2:4]), 1).strftime('%B')
        except ValueError:
            return f"[{number}] {title}, arXiv {digits}"
        year = ('19' if yy > 70 else '20') + digits[:2]
        return f"[{number}] {title}, arXiv {digits}, {month} {year}"

    def single_paper_related_work_generation(self, usr_prompt):
        """Generate a related-work section (with a references list) for ``usr_prompt``.

        Pipeline: propose up to 3 topics -> retrieve and select up to 5
        citations per topic -> write a citation sentence per paper -> one
        paragraph per topic -> merge paragraphs -> append references.
        """
        topics = self._split_topics(self.single_paper_topic_test(usr_prompt))

        # Keep each topic paired with its papers so dropping a topic with no
        # selected citations cannot shift later paragraphs onto the wrong topic.
        topic_papers = []
        for topic in topics:
            papers = self._select_citations_for_topic(usr_prompt, topic)
            if papers:
                topic_papers.append((topic, papers))

        # Citation sentences, with any \cite{...} commands removed.
        topic_sentences = []
        for _, papers in topic_papers:
            sentences = []
            for title, _raw_id in papers:
                sentence = self.single_paper_sentence_test(usr_prompt, title, "")
                sentences.append(re.sub(r'\\cite\{[^{}]+\}', "", sentence))
            topic_sentences.append(sentences)

        # One paragraph per topic; references are numbered consecutively.
        paragraphs = []
        references = []
        paper_citation_indicator = 1
        for (topic, papers), sentences in zip(topic_papers, topic_sentences):
            datapoint = {'usr_prompt': usr_prompt,
                         'nei_title': papers,
                         'nei_sentence': sentences,
                         'topic': topic,
                         'paper_citation_indicator': paper_citation_indicator}
            prompt = self._generate_paragraph_prompt(datapoint)
            paragraphs.append(self.get_llm_response(prompt, 'zeroshot')['content'])
            for offset, (title, raw_id) in enumerate(papers):
                references.append(self._format_reference(paper_citation_indicator + offset, title, raw_id))
            paper_citation_indicator += len(papers)

        # Merge the paragraphs and append the references list.
        prompt = self._generate_summary_prompt({'usr_prompt': usr_prompt, 'paragraphs': paragraphs})
        summary = self.get_llm_response(prompt, 'zeroshot')['content']
        return summary + "\n\n### References\n" + "\n".join(references)


def gen_related_work(message, graph_path, adapter_path):
    """Build a :class:`LitFM` (loading all models) and generate the related work for ``message``."""
    litfm_instance = LitFM(graph_path, adapter_path)
    return litfm_instance.single_paper_related_work_generation(message)