Spaces:
Build error
Build error
Delete tasks
Browse files- tasks/abs_2_title.py +0 -23
- tasks/abs_completion.py +0 -29
- tasks/citation_sentence.py +0 -25
- tasks/gen_related_work.py +0 -430
- tasks/influential_papers.py +0 -41
- tasks/intro_2_abs.py +0 -28
- tasks/link_pred.py +0 -23
- tasks/paper_retrieval.py +0 -21
tasks/abs_2_title.py
DELETED
|
@@ -1,23 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Generate a prompt for generating the title of a paper based on its abstract.
|
| 3 |
-
|
| 4 |
-
Args:
|
| 5 |
-
usr_input (str): A string containing the title and abstract of the paper in the format "Title: <title> Abstract: <abstract>".
|
| 6 |
-
template (dict): A dictionary containing the template for the prompt with a key "prompt_input".
|
| 7 |
-
|
| 8 |
-
Returns:
|
| 9 |
-
str: A formatted string with the instruction and abstract to be used as input for generating the title.
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
def abs_2_title(usr_input, template, has_predefined_template=False):
|
| 13 |
-
instruction = "Please generate the title of paper based on its abstract"
|
| 14 |
-
|
| 15 |
-
if has_predefined_template:
|
| 16 |
-
res = [
|
| 17 |
-
{"role": "system", "content": instruction},
|
| 18 |
-
{"role": "user", "content": usr_input},
|
| 19 |
-
]
|
| 20 |
-
else:
|
| 21 |
-
res = template["prompt_input"].format(instruction=instruction, input=usr_input)
|
| 22 |
-
|
| 23 |
-
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tasks/abs_completion.py
DELETED
|
@@ -1,29 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Generates a formatted prompt for completing the abstract of a paper.
|
| 3 |
-
|
| 4 |
-
Args:
|
| 5 |
-
usr_input (str): The user input containing the title and part of the abstract.
|
| 6 |
-
Expected format:
|
| 7 |
-
"Title: <title>\nAbstract: <abstract>"
|
| 8 |
-
template (dict): A dictionary containing the template for the prompt.
|
| 9 |
-
Expected format:
|
| 10 |
-
{"prompt_input": "<template_string>"}
|
| 11 |
-
The template string should contain placeholders for
|
| 12 |
-
'instruction' and 'input'.
|
| 13 |
-
|
| 14 |
-
Returns:
|
| 15 |
-
str: A formatted string with the instruction and the input embedded in the template.
|
| 16 |
-
"""
|
| 17 |
-
|
| 18 |
-
def abs_completion(usr_input, template, has_predefined_template=False):
|
| 19 |
-
instruction = "Please complete the abstract of a paper."
|
| 20 |
-
|
| 21 |
-
if has_predefined_template:
|
| 22 |
-
res = [
|
| 23 |
-
{"role": "system", "content": instruction},
|
| 24 |
-
{"role": "user", "content": usr_input},
|
| 25 |
-
]
|
| 26 |
-
else:
|
| 27 |
-
res = template["prompt_input"].format(instruction=instruction, input=usr_input)
|
| 28 |
-
|
| 29 |
-
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tasks/citation_sentence.py
DELETED
|
@@ -1,25 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Generates a citation sentence based on the titles and abstracts of two papers.
|
| 3 |
-
|
| 4 |
-
Args:
|
| 5 |
-
usr_input (str): A string containing the titles and abstracts of Paper A and Paper B.
|
| 6 |
-
The format should be:
|
| 7 |
-
"Title A: <title of paper A>\nAbstract A: <abstract of paper A>\nTitle B: <title of paper B>\nAbstract B: <abstract of paper B>"
|
| 8 |
-
template (dict): A dictionary containing a template for the prompt input. The key "prompt_input" should map to a string with placeholders for the instruction and input.
|
| 9 |
-
|
| 10 |
-
Returns:
|
| 11 |
-
str: A formatted string that combines the instruction and the prompt input with the provided titles and abstracts.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
def citation_sentence(usr_input, template, has_predefined_template=False):
|
| 15 |
-
instruction = "Please generate the citation sentence of how Paper A cites paper B in its related work section. \n"
|
| 16 |
-
|
| 17 |
-
if has_predefined_template:
|
| 18 |
-
res = [
|
| 19 |
-
{"role": "system", "content": instruction},
|
| 20 |
-
{"role": "user", "content": usr_input},
|
| 21 |
-
]
|
| 22 |
-
else:
|
| 23 |
-
res = template["prompt_input"].format(instruction=instruction, input=usr_input)
|
| 24 |
-
|
| 25 |
-
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tasks/gen_related_work.py
DELETED
|
@@ -1,430 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Generates the related work section for a given paper.
|
| 3 |
-
|
| 4 |
-
The input
|
| 5 |
-
- The input prompt is a string that contains the information of the paper for which the related work section needs to be generated.
|
| 6 |
-
- The input prompt should be in the following format:
|
| 7 |
-
Title of Paper: <title of the paper>
|
| 8 |
-
|
| 9 |
-
Abstract of Paper: <abstract of the paper>
|
| 10 |
-
The output
|
| 11 |
-
- The output is a string that contains the related work section for the given paper.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
import torch
|
| 15 |
-
import json
|
| 16 |
-
import networkx as nx
|
| 17 |
-
import numpy as np
|
| 18 |
-
from tqdm import tqdm
|
| 19 |
-
from peft import PeftModel
|
| 20 |
-
from transformers import (AutoModel, AutoTokenizer, AutoModelForCausalLM, pipeline)
|
| 21 |
-
from tqdm import tqdm
|
| 22 |
-
import re
|
| 23 |
-
import pandas as pd
|
| 24 |
-
import os
|
| 25 |
-
from sklearn.metrics.pairwise import cosine_similarity
|
| 26 |
-
from utils.utils import read_yaml_file
|
| 27 |
-
import datetime
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
class LitFM():
|
| 31 |
-
def __init__(self, graph_path, adapter_path):
|
| 32 |
-
self.graph_name = graph_path.split('.')[0].split('/')[-1] if '/' in graph_path else graph_path.split('.')[0]
|
| 33 |
-
self.batch_size = 32
|
| 34 |
-
self.neigh_num = 4
|
| 35 |
-
|
| 36 |
-
config = read_yaml_file('configs/config.yaml')
|
| 37 |
-
retrieval_graph_path = graph_path
|
| 38 |
-
|
| 39 |
-
self.pretrained_model = config['retriever']['embedder']
|
| 40 |
-
|
| 41 |
-
# define generation model
|
| 42 |
-
model_path = config['inference']["base_model"]
|
| 43 |
-
self.generation_tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 44 |
-
self.generation_tokenizer.model_max_length = 2048
|
| 45 |
-
if self.generation_tokenizer.pad_token is None:
|
| 46 |
-
self.generation_tokenizer.pad_token = self.generation_tokenizer.eos_token
|
| 47 |
-
self.generation_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, device_map="auto")
|
| 48 |
-
self.generation_model = PeftModel.from_pretrained(self.generation_model, adapter_path, adapter_name="instruction", torch_dtype=torch.float16)
|
| 49 |
-
self.model_pipeline = None
|
| 50 |
-
if self.generation_tokenizer.chat_template is not None:
|
| 51 |
-
self.model_pipeline = pipeline(
|
| 52 |
-
"text-generation",
|
| 53 |
-
model=model_path,
|
| 54 |
-
model_kwargs={"torch_dtype": torch.bfloat16},
|
| 55 |
-
device_map="auto",
|
| 56 |
-
)
|
| 57 |
-
|
| 58 |
-
# define instruction models
|
| 59 |
-
self.instruction_pipe = pipeline(
|
| 60 |
-
"text-generation",
|
| 61 |
-
model=config["inference"]["gen_related_work_instruct_model"],
|
| 62 |
-
model_kwargs={"torch_dtype": torch.bfloat16},
|
| 63 |
-
device_map="auto",
|
| 64 |
-
)
|
| 65 |
-
|
| 66 |
-
# load graph data for retrieval
|
| 67 |
-
def translate_graph(graph):
|
| 68 |
-
all_nodes = list(graph.nodes())
|
| 69 |
-
raw_id_2_id_dict = {}
|
| 70 |
-
id_2_raw_id_dict = {}
|
| 71 |
-
|
| 72 |
-
num = 0
|
| 73 |
-
for node in all_nodes:
|
| 74 |
-
raw_id_2_id_dict[node] = num
|
| 75 |
-
id_2_raw_id_dict[num] = node
|
| 76 |
-
num += 1
|
| 77 |
-
|
| 78 |
-
return raw_id_2_id_dict, id_2_raw_id_dict
|
| 79 |
-
|
| 80 |
-
whole_graph_data_raw = nx.read_gexf(retrieval_graph_path, node_type=None, relabel=False, version='1.2draft')
|
| 81 |
-
self.whole_graph_raw_id_2_id_dict, self.whole_graph_id_2_raw_id_dict = translate_graph(whole_graph_data_raw)
|
| 82 |
-
|
| 83 |
-
self.whole_graph_id_2_title_abs = dict()
|
| 84 |
-
for paper_id in whole_graph_data_raw.nodes():
|
| 85 |
-
title = whole_graph_data_raw.nodes()[paper_id]['title']
|
| 86 |
-
abstract = whole_graph_data_raw.nodes()[paper_id]['abstract']
|
| 87 |
-
self.whole_graph_id_2_title_abs[self.whole_graph_raw_id_2_id_dict[paper_id]] = [title, abstract]
|
| 88 |
-
|
| 89 |
-
# define prompt template
|
| 90 |
-
template_file_path = 'configs/alpaca.json'
|
| 91 |
-
with open(template_file_path) as fp:
|
| 92 |
-
self.template = json.load(fp)
|
| 93 |
-
self.human_instruction = ['### Input:', '### Response:']
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
def _generate_retrieval_prompt(self, data_point: dict):
|
| 97 |
-
instruction = "Please select the paper that is more likely to be cited by the paper from the list of candidate papers. Your answer MUST be **only the exact title** of the selected paper without generating ANY other text or section. Your answer MUST belong to the list of candidate papers.\n"
|
| 98 |
-
prompt_input = ""
|
| 99 |
-
prompt_input = prompt_input + data_point['usr_prompt'] + "\n"
|
| 100 |
-
prompt_input = prompt_input + "candidate papers: " + "\n"
|
| 101 |
-
for i in range(len(data_point['nei_titles'])):
|
| 102 |
-
prompt_input = prompt_input + str(i) + '. ' + data_point['nei_titles'][i] + "\n"
|
| 103 |
-
|
| 104 |
-
if self.model_pipeline is not None:
|
| 105 |
-
res = [
|
| 106 |
-
{"role": "system", "content": instruction},
|
| 107 |
-
{"role": "user", "content": prompt_input},
|
| 108 |
-
]
|
| 109 |
-
else:
|
| 110 |
-
res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
|
| 111 |
-
|
| 112 |
-
return res
|
| 113 |
-
|
| 114 |
-
def _generate_sentence_prompt(self, data_point):
|
| 115 |
-
instruction = "Please generate the citation sentence of how the Paper cites paper B in its related work section."
|
| 116 |
-
|
| 117 |
-
prompt_input = ""
|
| 118 |
-
prompt_input = prompt_input + data_point['usr_prompt'] + "\n"
|
| 119 |
-
prompt_input = prompt_input + "Title of Paper B: " + (data_point['t_title'] if data_point['t_title'] != None else 'Unknown') + "\n"
|
| 120 |
-
prompt_input = prompt_input + "Abstract of Paper B: " + (data_point['t_abs'] if data_point['t_abs'] != None else 'Unknown') + "\n"
|
| 121 |
-
|
| 122 |
-
if self.model_pipeline is not None:
|
| 123 |
-
res = [
|
| 124 |
-
{"role": "system", "content": instruction},
|
| 125 |
-
{"role": "user", "content": prompt_input},
|
| 126 |
-
]
|
| 127 |
-
else:
|
| 128 |
-
res = self.template["prompt_input"].format(instruction=instruction, input=prompt_input)
|
| 129 |
-
|
| 130 |
-
return res
|
| 131 |
-
|
| 132 |
-
def _generate_topic_prompt(self, data_point):
|
| 133 |
-
prompt_input = ""
|
| 134 |
-
prompt_input = prompt_input + "Here are the information of the paper: \n"
|
| 135 |
-
prompt_input = prompt_input + data_point['usr_prompt'] + '\n'
|
| 136 |
-
prompt_input = prompt_input + "Directlty give me the topics you select.\n"
|
| 137 |
-
|
| 138 |
-
res = [
|
| 139 |
-
{"role": "system", "content": "I need to write the related work section for this paper. Could you suggest three most relevant topics to discuss in the related work section? Your answer should be strictly one topic after the other line by line with nothing else being generated and no further explanation/information.\n"},
|
| 140 |
-
{"role": "user", "content": prompt_input},
|
| 141 |
-
]
|
| 142 |
-
|
| 143 |
-
return res
|
| 144 |
-
|
| 145 |
-
def _generate_paragraph_prompt(self, data_point):
|
| 146 |
-
prompt_input = ""
|
| 147 |
-
prompt_input = prompt_input + data_point['usr_prompt'] + "\n"
|
| 148 |
-
prompt_input = prompt_input + "Topic of this paragraph: " + data_point['topic'] + "\n"
|
| 149 |
-
prompt_input = prompt_input + "Papers that should be cited in paragraph: \n"
|
| 150 |
-
|
| 151 |
-
i = data_point['paper_citation_indicator']
|
| 152 |
-
for paper_idx in range(len(data_point['nei_title'])):
|
| 153 |
-
prompt_input = prompt_input + "[" + str(i) + "]. " + data_point['nei_title'][paper_idx][0] + '.' + " Citation sentence of this paper in the paragraph: " + data_point['nei_sentence'][paper_idx] + '\n'
|
| 154 |
-
i += 1
|
| 155 |
-
|
| 156 |
-
prompt_input = prompt_input + "All the above cited papers should be included and each cited paper should be indicated with its index number. Note that you should not include the title of any paper\n"
|
| 157 |
-
|
| 158 |
-
res = [
|
| 159 |
-
{"role": "system", "content": "Please write a paragraph that review the research relationships between this paper and other cited papers.\n"},
|
| 160 |
-
{"role": "user", "content": prompt_input},
|
| 161 |
-
]
|
| 162 |
-
|
| 163 |
-
return res
|
| 164 |
-
|
| 165 |
-
def _generate_summary_prompt(self, data_point):
|
| 166 |
-
prompt_input = ""
|
| 167 |
-
prompt_input = prompt_input + data_point['usr_prompt'] + "\n"
|
| 168 |
-
prompt_input = prompt_input + "Paragraphs that should be combined: " + "\n"
|
| 169 |
-
|
| 170 |
-
i = 1
|
| 171 |
-
for para in data_point['paragraphs']:
|
| 172 |
-
prompt_input = prompt_input + " Paragraph " + str(i) + ": " + para + '\n'
|
| 173 |
-
i += 1
|
| 174 |
-
|
| 175 |
-
res = [
|
| 176 |
-
{"role": "system", "content": "Please combine the following paragraphs in a cohenrent way that also keeps the citations and make the flow between paragraphs more smoothly\nAdd a sentence at the beginning of each paragraph to clarify its connection to the previous ones. Do not include any other surrounding text and not add a references list at all\n"},
|
| 177 |
-
{"role": "user", "content": prompt_input},
|
| 178 |
-
]
|
| 179 |
-
|
| 180 |
-
return res
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
@staticmethod
|
| 184 |
-
def generate_text(prompt, tokenizer, model, temperature, top_p, repetition_penalty, max_new_tokens):
|
| 185 |
-
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
| 186 |
-
|
| 187 |
-
with torch.no_grad():
|
| 188 |
-
output = model.generate(
|
| 189 |
-
**inputs,
|
| 190 |
-
do_sample=True,
|
| 191 |
-
temperature=temperature,
|
| 192 |
-
top_p=top_p,
|
| 193 |
-
repetition_penalty=repetition_penalty,
|
| 194 |
-
max_new_tokens=max_new_tokens,
|
| 195 |
-
pad_token_id=tokenizer.pad_token_id,
|
| 196 |
-
eos_token_id=tokenizer.eos_token_id,
|
| 197 |
-
use_cache=True,
|
| 198 |
-
)
|
| 199 |
-
output_text = tokenizer.decode(output[0], skip_special_tokens=True)
|
| 200 |
-
return output_text
|
| 201 |
-
|
| 202 |
-
def get_llm_response(self, prompt, model_type):
|
| 203 |
-
self.generation_model.set_adapter('instruction')
|
| 204 |
-
if model_type == 'zeroshot':
|
| 205 |
-
raw_output = self.instruction_pipe(
|
| 206 |
-
prompt,
|
| 207 |
-
max_new_tokens=8096,
|
| 208 |
-
temperature=0.9,
|
| 209 |
-
top_p=0.95,
|
| 210 |
-
repetition_penalty=1.15,
|
| 211 |
-
)[0]['generated_text'][-1]
|
| 212 |
-
|
| 213 |
-
if model_type == 'zeroshot_short':
|
| 214 |
-
raw_output = self.instruction_pipe(
|
| 215 |
-
prompt,
|
| 216 |
-
max_new_tokens=256,
|
| 217 |
-
temperature=0.9,
|
| 218 |
-
top_p=0.95,
|
| 219 |
-
repetition_penalty=1.15,
|
| 220 |
-
)[0]['generated_text'][-1]
|
| 221 |
-
|
| 222 |
-
if model_type == 'instruction':
|
| 223 |
-
self.generation_model.set_adapter('instruction')
|
| 224 |
-
if self.model_pipeline is not None:
|
| 225 |
-
raw_output = self.model_pipeline(
|
| 226 |
-
prompt,
|
| 227 |
-
temperature=0.9,
|
| 228 |
-
top_p=0.95,
|
| 229 |
-
repetition_penalty=1.15,
|
| 230 |
-
)[0]['generated_text'][-1]
|
| 231 |
-
else:
|
| 232 |
-
raw_output = self.generate_text(
|
| 233 |
-
prompt,
|
| 234 |
-
self.generation_tokenizer,
|
| 235 |
-
self.generation_model,
|
| 236 |
-
temperature=0.9,
|
| 237 |
-
top_p=0.95,
|
| 238 |
-
repetition_penalty=1.15,
|
| 239 |
-
max_new_tokens=256,
|
| 240 |
-
)
|
| 241 |
-
|
| 242 |
-
return raw_output
|
| 243 |
-
|
| 244 |
-
def single_paper_sentence_test(self, usr_prompt, t_title, t_abs):
|
| 245 |
-
datapoint = {'usr_prompt':usr_prompt, 't_title':t_title, 't_abs':t_abs}
|
| 246 |
-
prompt = self._generate_sentence_prompt(datapoint)
|
| 247 |
-
ans = self.get_llm_response(prompt, 'instruction')
|
| 248 |
-
res = ans.strip().split(self.human_instruction[1])[-1]
|
| 249 |
-
return res
|
| 250 |
-
|
| 251 |
-
def single_paper_retrieval_test(self, usr_prompt, candidates):
|
| 252 |
-
datapoint = {'usr_prompt':usr_prompt, 'nei_titles':list(candidates), 't_title': ''}
|
| 253 |
-
prompt = self._generate_retrieval_prompt(datapoint)
|
| 254 |
-
ans = self.get_llm_response(prompt, 'instruction')
|
| 255 |
-
res = ans.strip().split(self.human_instruction[1])[-1]
|
| 256 |
-
return res
|
| 257 |
-
|
| 258 |
-
def single_paper_topic_test(self, usr_prompt):
|
| 259 |
-
datapoint = {'usr_prompt': usr_prompt}
|
| 260 |
-
prompt = self._generate_topic_prompt(datapoint)
|
| 261 |
-
ans = self.get_llm_response(prompt, 'zeroshot_short')
|
| 262 |
-
res = ans['content']
|
| 263 |
-
res = res.replace('\n\n', '\n')
|
| 264 |
-
return res
|
| 265 |
-
|
| 266 |
-
def retrieval_for_one_query(self, id_2_title_abs, prompt):
|
| 267 |
-
if os.path.exists(f'datasets/{self.graph_name}_embeddings.parquet'):
|
| 268 |
-
all_query_embs = torch.tensor(np.array(pd.read_parquet(f'datasets/{self.graph_name}_embeddings.parquet')))
|
| 269 |
-
else:
|
| 270 |
-
tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-large-en-v1.5")
|
| 271 |
-
model = AutoModel.from_pretrained("BAAI/bge-large-en-v1.5").to(device='cuda', dtype=torch.float16)
|
| 272 |
-
model.eval()
|
| 273 |
-
|
| 274 |
-
paper_list = list(id_2_title_abs.keys())
|
| 275 |
-
|
| 276 |
-
all_query_embs = torch.zeros(len(paper_list), 1024)
|
| 277 |
-
i = 0
|
| 278 |
-
batch_size = 200
|
| 279 |
-
candidate_emb_list = []
|
| 280 |
-
pbar = tqdm(total=len(paper_list))
|
| 281 |
-
while i < len(paper_list):
|
| 282 |
-
paper_batch = paper_list[i:i+batch_size]
|
| 283 |
-
paper_text_batch = []
|
| 284 |
-
for paper_id in paper_batch:
|
| 285 |
-
prompt = id_2_title_abs[paper_id][0] + id_2_title_abs[paper_id][1]
|
| 286 |
-
paper_text_batch.append(prompt)
|
| 287 |
-
inputs = tokenizer(paper_text_batch, return_tensors='pt', padding=True, truncation=True)
|
| 288 |
-
|
| 289 |
-
with torch.no_grad():
|
| 290 |
-
outputs = model(**inputs.to('cuda'))
|
| 291 |
-
candidate_embeddings = outputs.last_hidden_state[:, 0, :].cpu()
|
| 292 |
-
candidate_embeddings = candidate_embeddings.reshape(-1, 1024)
|
| 293 |
-
candidate_emb_list.append(candidate_embeddings)
|
| 294 |
-
|
| 295 |
-
i += len(candidate_embeddings)
|
| 296 |
-
pbar.update(len(candidate_embeddings))
|
| 297 |
-
|
| 298 |
-
all_query_embs = torch.cat(candidate_emb_list, 0)
|
| 299 |
-
pd.DataFrame(all_query_embs.numpy()).to_parquet(f'datasets/{self.graph_name}_embeddings.parquet')
|
| 300 |
-
|
| 301 |
-
# get the embeddings of the prompt
|
| 302 |
-
pretrained_model_name = self.pretrained_model
|
| 303 |
-
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
|
| 304 |
-
LLM_model = AutoModel.from_pretrained(pretrained_model_name).cuda()
|
| 305 |
-
LLM_model.eval()
|
| 306 |
-
|
| 307 |
-
encoded_input = tokenizer([prompt], padding = True, truncation=True, max_length=512 , return_tensors='pt')
|
| 308 |
-
with torch.no_grad():
|
| 309 |
-
output = LLM_model(**encoded_input.to('cuda'), output_hidden_states=True).hidden_states[-1]
|
| 310 |
-
sentence_embedding = output[:, 0, :]
|
| 311 |
-
|
| 312 |
-
tmp_scores = cosine_similarity(sentence_embedding.to("cpu"), all_query_embs.to("cpu"))[0]
|
| 313 |
-
_, idxs = torch.sort(torch.tensor(tmp_scores), descending=True)
|
| 314 |
-
top_10 = [int(k) for k in idxs[:10]]
|
| 315 |
-
|
| 316 |
-
return [id_2_title_abs[i][0] for i in top_10], [self.whole_graph_id_2_raw_id_dict[i] for i in top_10]
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
def single_paper_related_work_generation(self, usr_prompt):
|
| 320 |
-
citation_papers = []
|
| 321 |
-
nei_sentence = []
|
| 322 |
-
|
| 323 |
-
# Get topics
|
| 324 |
-
retrieval_query = self.single_paper_topic_test(usr_prompt)
|
| 325 |
-
|
| 326 |
-
# Split topics
|
| 327 |
-
topic_num = 3
|
| 328 |
-
try:
|
| 329 |
-
split_topics = retrieval_query.strip().split('\n')
|
| 330 |
-
if split_topics[0] == '':
|
| 331 |
-
split_topics = split_topics[1:]
|
| 332 |
-
split_topics = split_topics[:topic_num]
|
| 333 |
-
except:
|
| 334 |
-
split_topics = retrieval_query.strip().split(':')
|
| 335 |
-
split_topics = split_topics.strip().split(';')
|
| 336 |
-
split_topics = split_topics[:topic_num]
|
| 337 |
-
if len(split_topics) > topic_num:
|
| 338 |
-
return ["too many topics", split_topics]
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
# Get top-5 papers for each topic
|
| 342 |
-
for retrieval_query in split_topics:
|
| 343 |
-
# retrieve papers
|
| 344 |
-
candidate_citation_papers, candidate_raw_ids = self.retrieval_for_one_query(self.whole_graph_id_2_title_abs, retrieval_query)
|
| 345 |
-
topic_specific_citation_papers = []
|
| 346 |
-
# select top-5 papers
|
| 347 |
-
for _ in range(5):
|
| 348 |
-
# picking most likely to be cited paper
|
| 349 |
-
selected_paper = self.single_paper_retrieval_test(usr_prompt, candidate_citation_papers).replace(' \n','').replace('\n','')
|
| 350 |
-
|
| 351 |
-
words = selected_paper.strip().split(' ')
|
| 352 |
-
index = -1
|
| 353 |
-
for w in words:
|
| 354 |
-
try:
|
| 355 |
-
index = int(w)
|
| 356 |
-
except:
|
| 357 |
-
pass
|
| 358 |
-
|
| 359 |
-
if index != -1 and index < len(candidate_citation_papers):
|
| 360 |
-
paper_title = candidate_citation_papers[index]
|
| 361 |
-
candidate_citation_papers = list(set(candidate_citation_papers) - set([paper_title]))
|
| 362 |
-
topic_specific_citation_papers.append([paper_title, candidate_raw_ids[index]])
|
| 363 |
-
else:
|
| 364 |
-
for i, paper_title in enumerate(list(candidate_citation_papers)):
|
| 365 |
-
if paper_title.lower().replace(' ', '') in selected_paper.lower().replace(' ', '') or selected_paper.lower().replace(' ', '') in paper_title.lower().replace(' ', ''):
|
| 366 |
-
candidate_citation_papers = list(set(candidate_citation_papers) - set([paper_title]))
|
| 367 |
-
topic_specific_citation_papers.append([paper_title, candidate_raw_ids[i]])
|
| 368 |
-
break
|
| 369 |
-
citation_papers.append(topic_specific_citation_papers)
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
# Remove empty lists
|
| 373 |
-
citation_papers = [x for x in citation_papers if x != []]
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
# Generate citation sentences
|
| 377 |
-
for topic_idx in range(len(citation_papers)):
|
| 378 |
-
topic_specific_nei_sentence = []
|
| 379 |
-
for paper_idx in range(len(citation_papers[topic_idx])):
|
| 380 |
-
sentence = self.single_paper_sentence_test(usr_prompt, citation_papers[topic_idx][paper_idx][0], "")
|
| 381 |
-
# Match \cite{...}
|
| 382 |
-
sentence = re.sub(r'\\cite\{[^{}]+\}', "", sentence)
|
| 383 |
-
topic_specific_nei_sentence.append(sentence)
|
| 384 |
-
nei_sentence.append(topic_specific_nei_sentence)
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
# Generate paragraphs
|
| 388 |
-
paragraphs = []
|
| 389 |
-
references = [] # Store references for citation
|
| 390 |
-
paper_citation_indicator = 1 # Indicator for citation paper
|
| 391 |
-
for topic_idx in range(len(citation_papers)):
|
| 392 |
-
datapoint = {'usr_prompt': usr_prompt,
|
| 393 |
-
'nei_title': citation_papers[topic_idx],
|
| 394 |
-
'nei_sentence': nei_sentence[topic_idx],
|
| 395 |
-
'topic': split_topics[topic_idx],
|
| 396 |
-
'paper_citation_indicator': paper_citation_indicator}
|
| 397 |
-
|
| 398 |
-
prompt = self._generate_paragraph_prompt(datapoint)
|
| 399 |
-
ans = self.get_llm_response(prompt, 'zeroshot')
|
| 400 |
-
res = ans['content']
|
| 401 |
-
paragraphs.append(res)
|
| 402 |
-
|
| 403 |
-
# Store referencess
|
| 404 |
-
for ref_idx, paper in enumerate(citation_papers[topic_idx]):
|
| 405 |
-
# Extract year and month from raw_id
|
| 406 |
-
raw_id = re.sub(r'[a-zA-Z/]+', '', paper[1])
|
| 407 |
-
year = raw_id[:2]
|
| 408 |
-
year = '19' + year if int(year) > 70 else '20' + year
|
| 409 |
-
month = datetime.date(1900, int(raw_id[2:4]), 1).strftime('%B')
|
| 410 |
-
|
| 411 |
-
references.append(f"[{paper_citation_indicator + ref_idx}] {paper[0]}, arXiv {raw_id}, {month} {year}")
|
| 412 |
-
# Update paper_citation_indicator
|
| 413 |
-
paper_citation_indicator = paper_citation_indicator + len(nei_sentence[topic_idx])
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
# Generate summary
|
| 417 |
-
datapoint = {'usr_prompt': usr_prompt, 'paragraphs': paragraphs}
|
| 418 |
-
prompt = self._generate_summary_prompt(datapoint)
|
| 419 |
-
ans = self.get_llm_response(prompt, 'zeroshot')
|
| 420 |
-
summary = ans['content']
|
| 421 |
-
|
| 422 |
-
# Append references to summary
|
| 423 |
-
summary_with_references = summary + "\n\n### References\n" + "\n".join(references)
|
| 424 |
-
|
| 425 |
-
return summary_with_references
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
def gen_related_work(message, graph_path, adapter_path):
|
| 429 |
-
litfm_instance = LitFM(graph_path, adapter_path)
|
| 430 |
-
return litfm_instance.single_paper_related_work_generation(message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tasks/influential_papers.py
DELETED
|
@@ -1,41 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Influential Papers Task
|
| 3 |
-
|
| 4 |
-
This module provides functionality to identify the most influential papers in a citation graph.
|
| 5 |
-
|
| 6 |
-
Functions:
|
| 7 |
-
influential_papers(K, graph):
|
| 8 |
-
Given an integer K and a citation graph, returns the K most influential papers based on the number of citations.
|
| 9 |
-
The function returns the title and abstract of each of the K most influential papers in a formatted string.
|
| 10 |
-
|
| 11 |
-
Usage:
|
| 12 |
-
The script reads configuration from a YAML file, loads a citation graph from a GEXF file, and prints the K most influential papers.
|
| 13 |
-
"""
|
| 14 |
-
|
| 15 |
-
import datetime
|
| 16 |
-
import re
|
| 17 |
-
|
| 18 |
-
def influential_papers(message, graph):
|
| 19 |
-
# Get integer number from message
|
| 20 |
-
K = int(re.search(r'\d+', message).group())
|
| 21 |
-
|
| 22 |
-
in_degree = dict(graph.in_degree())
|
| 23 |
-
sorted_in_degree = sorted(in_degree.items(), key=lambda x: x[1], reverse=True)
|
| 24 |
-
|
| 25 |
-
most_cited_papers = []
|
| 26 |
-
for i in range(K):
|
| 27 |
-
node = sorted_in_degree[i]
|
| 28 |
-
paper = graph.nodes[node[0]]
|
| 29 |
-
most_cited_papers.append(paper)
|
| 30 |
-
|
| 31 |
-
resp = "Here are the most influential papers:\n"
|
| 32 |
-
for i, paper in enumerate(most_cited_papers):
|
| 33 |
-
full_paper_id = paper['label']
|
| 34 |
-
paper_id = re.sub(r'[a-zA-Z/]+', '', full_paper_id)
|
| 35 |
-
year = paper_id[:2]
|
| 36 |
-
year = '19' + year if int(year) > 70 else '20' + year
|
| 37 |
-
month = datetime.date(1900, int(paper_id[2:4]), 1).strftime('%B')
|
| 38 |
-
|
| 39 |
-
resp += f"{i+1}. Title: {paper['title']}, arXiv {full_paper_id}, {month} {year} \nAbstract: {paper['abstract']}\n"
|
| 40 |
-
|
| 41 |
-
return resp
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tasks/intro_2_abs.py
DELETED
|
@@ -1,28 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Generate the abstract of a paper based on its introduction section.
|
| 3 |
-
|
| 4 |
-
Args:
|
| 5 |
-
usr_prompt (str): The user-provided prompt containing the introduction section of the paper.
|
| 6 |
-
template (dict): A dictionary containing the template for generating the abstract.
|
| 7 |
-
context_window (int): The maximum length of the context window for the prompt input.
|
| 8 |
-
|
| 9 |
-
Returns:
|
| 10 |
-
str: The generated abstract based on the introduction section.
|
| 11 |
-
"""
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def intro_2_abs(usr_prompt, template, context_window, has_predefined_template=False):
|
| 15 |
-
instruction = "Please generate the abstract of paper based on its introduction section."
|
| 16 |
-
|
| 17 |
-
# Reduce it to make it fit
|
| 18 |
-
prompt_input = usr_prompt[:int(context_window*2)]
|
| 19 |
-
|
| 20 |
-
if has_predefined_template:
|
| 21 |
-
res = [
|
| 22 |
-
{"role": "system", "content": instruction},
|
| 23 |
-
{"role": "user", "content": prompt_input},
|
| 24 |
-
]
|
| 25 |
-
else:
|
| 26 |
-
res = template["prompt_input"].format(instruction=instruction, input=prompt_input)
|
| 27 |
-
|
| 28 |
-
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tasks/link_pred.py
DELETED
|
@@ -1,23 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Determine if paper A will cite paper B.
|
| 3 |
-
|
| 4 |
-
Args:
|
| 5 |
-
usr_input (str): The user-provided input containing the titles and abstracts of papers A and B.
|
| 6 |
-
template (dict): A dictionary containing the template for generating the link prediction task.
|
| 7 |
-
|
| 8 |
-
Returns:
|
| 9 |
-
str: The generated link prediction task based on the user input.
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
def link_pred(usr_input, template, has_predefined_template=False):
|
| 13 |
-
instruction = "Determine if paper A will cite paper B."
|
| 14 |
-
|
| 15 |
-
if has_predefined_template:
|
| 16 |
-
res = [
|
| 17 |
-
{"role": "system", "content": instruction},
|
| 18 |
-
{"role": "user", "content": usr_input},
|
| 19 |
-
]
|
| 20 |
-
else:
|
| 21 |
-
res = template["prompt_input"].format(instruction=instruction, input=usr_input)
|
| 22 |
-
|
| 23 |
-
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tasks/paper_retrieval.py
DELETED
|
@@ -1,21 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Retrieves the most likely paper to be cited by Paper A from a list of candidate papers based on user input.
|
| 3 |
-
Args:
|
| 4 |
-
usr_input (str): A string containing the title and abstract of Paper A followed by the titles and abstracts of candidate papers.
|
| 5 |
-
template (dict): A dictionary containing a template for formatting the prompt input.
|
| 6 |
-
Returns:
|
| 7 |
-
str: A string containing the prompt input for the user.
|
| 8 |
-
"""
|
| 9 |
-
|
| 10 |
-
def paper_retrieval(usr_input, template, has_predefined_template=False):
|
| 11 |
-
instruction = "Please select the paper that is more likely to be cited by paper A from candidate papers."
|
| 12 |
-
|
| 13 |
-
if has_predefined_template:
|
| 14 |
-
res = [
|
| 15 |
-
{"role": "system", "content": instruction},
|
| 16 |
-
{"role": "user", "content": usr_input},
|
| 17 |
-
]
|
| 18 |
-
else:
|
| 19 |
-
res = template["prompt_input"].format(instruction=instruction, input=usr_input)
|
| 20 |
-
|
| 21 |
-
return res
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|