from train import * from utils.utils import * from utils.graph_utils import * from utils.gradio_utils import * from retriever.retriever import retriever from tasks.abs_2_title import abs_2_title from tasks.abs_completion import abs_completion from tasks.citation_sentence import citation_sentence from tasks.intro_2_abs import intro_2_abs from tasks.link_pred import link_pred from tasks.paper_retrieval import paper_retrieval from tasks.influential_papers import influential_papers from tasks.gen_related_work import gen_related_work import random import json import os import re import networkx as nx import tarfile import gzip import time import urllib.request from tqdm import tqdm from colorama import Fore import wandb import gradio as gr from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList, TextIteratorStreamer, pipeline from threading import Thread import signal import gzip import time import torch from peft.peft_model import PeftModel from datasets import load_dataset from huggingface_hub import login import spaces os.environ["CUDA_VISIBLE_DEVICES"]="0" # Load token from Hugging Face Secrets HF_TOKEN = os.getenv("access_token") # Login using the token if HF_TOKEN: login(HF_TOKEN) # Function to determine the chatbot's first message based on user choices def setup(download_option, train_option): download_papers.value = (download_option == "Download Paper") train_model.value = (train_option == "Train") if download_option == "Download Paper": initial_message = [{"role": "assistant", "content": "Hello, what domain are you interested in?"}] elif download_option != "Download Paper" and train_option == "Train": initial_message = [{"role": "assistant", "content": "What domain is your graph about?"}] else: initial_message = [{"role": "assistant", "content": "Please provide your task prompt."}] return gr.update(visible=False), gr.update(visible=True), f"Download: {download_option}\nTrain: {train_option}", initial_message # Function to toggle the selected task based on user input def update_button_styles(selected_task): """Update button styles based on selection.""" return [gr.update(variant="primary" if selected_task == prompt else "secondary") for prompt in task_list] # Fetch and store arXiv source files def fetch_arxiv_papers(papers_to_download): # Download the arXiv metadata file if it doesn't exist dataset = 'datasets/arxiv-metadata-oai-snapshot.json' data = [] if not os.path.exists(dataset): os.system("wget https://huggingface.co/spaces/ddiddu/simsearch/resolve/main/arxiv-metadata-oai-snapshot.json -P ./datasets") with open(dataset, 'r') as f: for line in f: data.append(json.loads(line)) papers = [d for d in data] paper_ids = [d['id'] for d in data] paper_titles = [ ( re.sub(r' +', ' ', re.sub(r'[\n]+', ' ', paper['title'])) .replace("\\emph", "") .replace("\\emp", "") .replace("\\em", "") .replace(",", "") .replace("{", "") .replace("}", "") .strip(".") .strip() .strip(".") .lower() ) for paper in papers ] paper_dict = { k:v for k,v in zip(paper_titles, paper_ids) } total_papers = len(papers_to_download) download_progress_bar=gr.Progress() llm_resp = [] results = { "Number of papers": 0, "Number of latex papers": 0, "Number of bib files": 0, "Number of bbl files": 0, "Number of inline files": 0, "Number of introductions found": 0, "Number of related works found": 0, "Number of succesful finding of extracts": 0 } num_papers, num_edges, t, iter_ind = 0, 0, 0, 0 graph = {} arxiv_rate_lim = config['data_downloading']['processing']['arxiv_rate_limit'] for paper_name in tqdm(papers_to_download): results["Number of papers"] += 1 print( Fore.BLUE + "Number of papers processed: {} \n Number of edges found: {} \n Time of previous iter: {} \n Now processing paper: {} \n\n" .format(num_papers, num_edges, time.time()-t, paper_name) + Fore.RESET ) t = time.time() num_papers += 1 # Prepare the paper name for downloading and saving paper_name_download = paper_name if re.search(r'[a-zA-Z]', paper_name) is not None: paper_name = "".join(paper_name.split('/')) tar_file_path = save_zip_directory + paper_name + '.tar.gz' # Attempt to download the paper source files from arXiv try: # Track start time for download t1 = time.time() urllib.request.urlretrieve( "https://arxiv.org/src/"+paper_name_download, tar_file_path) except Exception as e: print("Couldn't download paper {}".format(paper_name)) # Skip to the next paper if download fails continue # Define the directory where the paper will be extracted extracted_dir = save_directory + paper_name + '/' isExist = os.path.exists(extracted_dir) if not isExist: os.makedirs(extracted_dir) # Attempt to extract the tar.gz archive try: tar = tarfile.open(tar_file_path) tar.extractall(extracted_dir) tar.close() except Exception as e: # If tar extraction fails, attempt to read and extract using gzip try: with gzip.open(tar_file_path, 'rb') as f: file_content = f.read() # Save the extracted content as a .tex file with open(extracted_dir+paper_name+'.tex', 'w') as f: f.write(file_content.decode()) except Exception as e: print("Could not extract paper id: {}".format(paper_name)) # Skip this paper if extraction fails continue try: # Perform initial cleaning and get the main TeX file initial_clean(extracted_dir, config=False) main_file = get_main(extracted_dir) # If no main TeX file is found, remove the downloaded archive and continue if main_file == None: print("No tex files found") os.remove(tar_file_path) continue # Check if the main TeX file contains a valid LaTeX document h = check_begin(main_file) if h == True: results["Number of latex papers"] += 1 # Flag to check for internal bibliography check_internal = 0 # Dictionary to store bibliographic references final_library = {} # Identify bibliography files (.bib or .bbl) bib_files = find_bib(extracted_dir) if bib_files == []: bbl_files = find_bbl(extracted_dir) if bbl_files == []: # No external bibliography found check_internal = 1 else: final_library = get_library_bbl(bbl_files) results["Number of bbl files"] += 1 else: results["Number of bib files"] += 1 final_library = get_library_bib(bib_files) # Apply post-processing to clean the TeX document main_file = post_processing(extracted_dir, main_file) # Read the cleaned LaTeX document content descr = main_file content = read_tex_file(descr) # If configured, store the raw content in the graph if config['data_downloading']['processing']['keep_unstructured_content']: graph[paper_name] = {'content': content} else: graph[paper_name] = {} # Check for inline bibliography within the LaTeX document if check_internal == 1: beginning_bib = '\\begin{thebibliography}' end_bib = '\\end{thebibliography}' if content.find(beginning_bib) != -1 and content.find(end_bib) != -1: bibliography = content[content.find(beginning_bib):content.find(end_bib) + len(end_bib)] save_bbl = os.path.join(extracted_dir, "bibliography.bbl") results["Number of inline files"] += 1 with open(save_bbl, "w") as f: f.write(bibliography) final_library = get_library_bbl([save_bbl]) # If no valid bibliography is found, skip processing citations if final_library == {}: print("No library found...") continue # Extract relevant sections such as "Related Work" and "Introduction" related_works = get_related_works(content) if related_works != '': graph[paper_name]['Related Work'] = related_works results["Number of intro/related found"] += 1 intro = get_intro(content) if intro != '': graph[paper_name]['Introduction'] = intro results["Number of introductions found"] += 1 # Extract citation sentences from the introduction and related works sentences_citing = get_citing_sentences(intro + '\n' + related_works) # Map citations to corresponding papers raw_sentences_citing = {} for k,v in sentences_citing.items(): new_values = [] for item in v: try: new_values.append(paper_dict[final_library[item]['title']]) except Exception as e: pass if new_values != []: raw_sentences_citing[k] = new_values # Construct citation edges edges_set = [] for k,v in raw_sentences_citing.items(): for item in v: edges_set.append((paper_name_download, item, {"sentence":k})) iter_ind +=1 if len(edges_set) !=0: results["Number of succesful finding of extracts"] += 1 graph[paper_name]['Citations'] = edges_set num_edges += len(edges_set) # Save progress after every 10 iterations if iter_ind % 10 == 0: print("Saving graph now") with open(save_path, 'w') as f: json.dump(results, f) with open(save_graph, 'w') as f: json.dump(graph, f) except Exception as e: print("Could not get main paper {}".format(paper_name)) # Update the progress bar after processing each paper download_progress_bar(num_papers / total_papers) # Ensure a minimum time gap of 3 seconds between iterations to avoid bans from arXiv t2 = time.time() # End time elapsed_time = t2 - t1 if elapsed_time < arxiv_rate_lim: time.sleep(arxiv_rate_lim - elapsed_time) # Final saving of processed data with open(save_graph, 'w') as f: json.dump(graph, f) with open(save_path, 'w') as f: json.dump(results, f) # Log final completion message llm_resp.append("✅ Successfully downloaded and cleaned {} papers.".format(results["Number of latex papers"])) return "\n".join(llm_resp) # Chat prediction function @spaces.GPU(duration=120) def predict(message, history, selected_task): global model # Initialize the conversation string conversation = "" # Parse the history: Gradio `type="messages"` uses dictionaries with 'role' and 'content' for item in history: if item["role"] == "assistant": conversation += f": {item['content']}\n" elif item["role"] == "user": conversation += f": {item['content']}\n" # Add the user's current message to the conversation conversation += f": {message}\n:" # Handle preferences if len(history) == 0: if not download_papers.value and not train_model.value: yield "✅ Using model from configuration file..." adapter_path = config["inference"]["pretrained_model"] peft_model = PeftModel.from_pretrained(model, adapter_path, torch_dtype=torch.float16) # change the global model with peft model model = peft_model time.sleep(2.5) if not (len(history) == 0 and (train_model.value or download_papers.value)): # Streamer for generating responses streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True) stop = StopOnTokens() generate_kwargs = { "streamer": streamer, "max_new_tokens": config['inference']['generation_args']["max_new_tokens"], "do_sample": config['inference']['generation_args']["do_sample"], "top_p": config['inference']['generation_args']["top_p"], "top_k": config['inference']['generation_args']["top_k"], "temperature": config['inference']['generation_args']["temperature"], "no_repeat_ngram_size": config['inference']['generation_args']["no_repeat_ngram_size"], "num_beams": config['inference']['generation_args']["num_beams"], "stopping_criteria": StoppingCriteriaList([stop]), } def generate_response(model, generate_kwargs, selected_task): global advanced_tasks_out has_predefined_template = generate_kwargs["streamer"].tokenizer.chat_template is not None if selected_task == "Abstract Completion": prompt = abs_completion(message, template, has_predefined_template) elif selected_task == "Title Generation": prompt = abs_2_title(message, template, has_predefined_template) elif selected_task == "Citation Recommendation": prompt = paper_retrieval(message, template, has_predefined_template) elif selected_task == "Citation Sentence Generation": prompt = citation_sentence(message, template, has_predefined_template) elif selected_task == "Citation Link Prediction": prompt = link_pred(message, template, has_predefined_template) elif selected_task == "Introduction to Abstract": prompt = intro_2_abs(message, template, tokenizer.model_max_length, has_predefined_template) elif selected_task == "Influential Papers Recommendation": if download_papers.value: graph = nx.read_gexf(gexf_file) advanced_tasks_out = influential_papers(message, graph) else: graph = nx.read_gexf(predef_graph) advanced_tasks_out = influential_papers(message, graph) elif selected_task == "Related Work Generation": adapter_path = ( f"{config['model_saving']['model_output_dir']}/{config['model_saving']['model_name']}_{config['model_saving']['index']}_adapter_test_graph" if train_model.value else config['inference']['pretrained_model'] ) if download_papers.value: advanced_tasks_out = gen_related_work(message, gexf_file, adapter_path) else: advanced_tasks_out = gen_related_work(message, predef_graph, adapter_path) else: prompt = conversation + f": {message}\n:" if selected_task != "Influential Papers Recommendation" and selected_task != "Related Work Generation": if tokenizer.chat_template is not None: response = model_pipeline(prompt, **generate_kwargs) streamer.put(response[0]['generated_text'][-1]) else: model_inputs = tokenizer(prompt, return_tensors="pt").to("cuda") generate_kwargs["inputs"] = model_inputs["input_ids"] generate_kwargs["attention_mask"] = model_inputs["attention_mask"] response = model.generate(**generate_kwargs) streamer.put(response) # Generate the response in a separate thread t = Thread(target=generate_response, kwargs={ "model": model, "generate_kwargs": generate_kwargs, "selected_task": selected_task }) global advanced_tasks_out advanced_tasks_out = None t.start() # Stream the partial response if selected_task != "Influential Papers Recommendation" and selected_task != "Related Work Generation": partial_message = "" for new_token in streamer: if new_token != '<': # Ignore placeholder tokens partial_message += new_token yield partial_message else: if selected_task == "Related Work Generation": yield "🔍 Generating related work..." while advanced_tasks_out == None: time.sleep(0.1) yield advanced_tasks_out # Fetch arXiv papers if the user opted to download them if len(history) == 0: if download_papers.value: # Fetch relevant papers yield "🔍 Retrieving relevant papers..." retrieve_progress = gr.Progress() for percent in retriever(message, retrieval_nodes_path): retrieve_progress(percent) with open(retrieval_nodes_path, "r") as f: data_download = json.load(f) papers_to_download = list(data_download.keys()) yield f"📥 Fetching {len(papers_to_download)} arXiv papers' source files... Please wait." content = fetch_arxiv_papers(papers_to_download) yield content time.sleep(2.5) # Train the model with the retrieved graph if len(history) == 0: if train_model.value: training_progress=gr.Progress() training_progress(0.0) # If the user opted to download papers, use the retrieved graph, else use the predefined graph if download_papers.value: yield "🚀 Training the model with the retrieved graph..." with open(save_graph, "r") as f: data_graph = json.load(f) renamed_data = { "/".join(re.match(r"([a-z-]+)([0-9]+)", key, re.I).groups()) if re.match(r"([a-z-]+)([0-9]+)", key, re.I) else key: value for key, value in data_graph.items() } concept_data = load_dataset("AliMaatouk/arXiv_Topics", cache_dir="datasets/arxiv_topics") id2topics = { entry["paper_id"]: [entry["Level 1"], entry["Level 2"], entry["Level 3"]] for entry in concept_data["train"] } dataset = 'datasets/arxiv-metadata-oai-snapshot.json' data = [] if not os.path.exists(dataset): os.system("wget https://huggingface.co/spaces/ddiddu/simsearch/resolve/main/arxiv-metadata-oai-snapshot.json -P ./datasets") with open(dataset, 'r') as f: for line in f: data.append(json.loads(line)) papers = {d['id']: d for d in data} G = nx.DiGraph() for k in renamed_data: if k not in G and k in papers: if config['data_downloading']['processing']['keep_unstructured_content']: G.add_node( k, title=papers[k]['title'], abstract=papers[k]['abstract'], introduction=renamed_data[k].get('Introduction', '') if renamed_data[k].get('Introduction', '') != '\n' else '', related=renamed_data[k].get('Related Work', '') if renamed_data[k].get('Related Work', '') != '\n' else '', concepts=", ".join(list(set(item for sublist in id2topics[k] for item in sublist))) if k in id2topics else '', content=renamed_data[k].get('content', '') if k in renamed_data else '' ) else: G.add_node( k, title=papers[k]['title'], abstract=papers[k]['abstract'], introduction=renamed_data[k].get('Introduction', '') if renamed_data[k].get('Introduction', '') != '\n' else '', related=renamed_data[k].get('Related Work', '') if renamed_data[k].get('Related Work', '') != '\n' else '', concepts=", ".join(list(set(item for sublist in id2topics[k] for item in sublist))) if k in id2topics else '' ) if 'Citations' in renamed_data[k]: for citation in renamed_data[k]['Citations']: source, target, metadata = citation sentence = metadata.get('sentence', '') # Extract sentence or default to empty string if target not in G and target in papers: if config['data_downloading']['processing']['keep_unstructured_content']: G.add_node( target, title=papers[target]['title'], abstract=papers[target]['abstract'], introduction=renamed_data[target].get('Introduction', '') if target in renamed_data and renamed_data[target].get('Introduction', '') != '\n' else '', related=renamed_data[target].get('Related Work', '') if target in renamed_data and renamed_data[target].get('Related Work', '') != '\n' else '', concepts=", ".join(list(set(item for sublist in concept_data[target].values() for item in sublist))) if target in concept_data else '', content=renamed_data[target].get('content', '') if target in renamed_data else '' ) else: G.add_node( target, title=papers[target]['title'], abstract=papers[target]['abstract'], introduction=renamed_data[target].get('Introduction', '') if target in renamed_data and renamed_data[target].get('Introduction', '') != '\n' else '', related=renamed_data[target].get('Related Work', '') if target in renamed_data and renamed_data[target].get('Related Work', '') != '\n' else '', concepts=", ".join(list(set(item for sublist in concept_data[target].values() for item in sublist))) if target in concept_data else '' ) G.add_edge(source, target, sentence=sentence) G.remove_nodes_from(list(nx.isolates(G))) nx.write_gexf(G, gexf_file) print(f"Processed graph written to {gexf_file}") else: yield f"✅ Using predefined graph: {predef_graph}" wandb.init(project='qlora_train') if download_papers.value: trainer = QloraTrainer_CS(config=config, use_predefined_graph=False) else: trainer = QloraTrainer_CS(config=config, use_predefined_graph=True) print("Load base model") trainer.load_base_model() print("Start training") def update_progress(): # Wait for the trainer to be initialized while trainer.transformer_trainer is None: time.sleep(0.5) time.sleep(1.5) # Update the progress bar until training is complete while trainer.transformer_trainer.state.global_step != trainer.transformer_trainer.state.max_steps: progress_bar = ( trainer.transformer_trainer.state.global_step / trainer.transformer_trainer.state.max_steps ) training_progress(progress_bar) time.sleep(0.5) training_progress(1.0) t1 = Thread(target=trainer.train) t1.start() t2 = Thread(target=update_progress()) t2.start() t1.join() t2.join() yield "🎉 Model training complete! Please provide your task prompt." adapter_path = f"{config['model_saving']['model_output_dir']}/{config['model_saving']['model_name']}_{config['model_saving']['index']}_adapter_test_graph" peft_model = PeftModel.from_pretrained(model, adapter_path, torch_dtype=torch.float16) # change the global model with peft model model = peft_model if __name__ == "__main__": print("This is running in a virtual environment: {}".format(is_venv())) config = read_yaml_file("configs/config.yaml") template_file_path = 'configs/alpaca.json' template = json.load(open(template_file_path, "r")) seed_no = config['data_downloading']['processing']['random_seed'] model_name = config['inference']['base_model'] working_dir = config['data_downloading']['download_directory'] save_zip_directory = working_dir + 'research_papers_zip/' save_directory = working_dir + 'research_papers/' save_description = working_dir + 'description/' save_path = save_description + 'results.json' save_graph = save_description + 'test_graph.json' gexf_file = save_description + config['data_downloading']['gexf_file'] predef_graph = 'datasets/' + config['training']['predefined_graph_path'] retrieval_nodes_path = 'datasets/retrieval_nodes.json' isExist = os.path.exists(save_zip_directory) if not isExist: os.makedirs(save_zip_directory) isExist = os.path.exists(save_directory) if not isExist: os.makedirs(save_directory) isExist = os.path.exists(save_description) if not isExist: os.makedirs(save_description) random.seed(seed_no) # Load model and tokenizer bnb_config = BitsAndBytesConfig( load_in_8bit=True, bnb_8bit_use_double_quant=True, bnb_8bit_quant_type="nf8", bnb_8bit_compute_dtype=torch.bfloat16 ) # Load model and tokenizer tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config) if model.device.type != 'cuda': model.to('cuda') if tokenizer.chat_template is not None: model_pipeline = pipeline( "text-generation", model=model_name, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto", ) signal.signal(signal.SIGINT, signal_handler) # Global States for User Preferences download_papers = gr.State(value=True) # Default: Download papers train_model = gr.State(value=True) # Default: Train the model # Categorized Recommended Prompts task_list = { "Abstract Completion", "Introduction to Abstract", "Title Generation", "Citation Recommendation", "Citation Sentence Generation", "Citation Link Prediction", "Influential Papers Recommendation", "Related Work Generation", } # CSS for Styling css = """ body { background-color: #E0F7FA; margin: 0; padding: 0; } .gradio-container { background-color: #E0F7FA; border-radius: 10px; } #logo-container { display: flex; justify-content: center; align-items: center; margin: 0 auto; padding: 0; max-width: 120px; height: 120px; border-radius: 10px; overflow: hidden; } #scroll-menu { max-height: 310px; overflow-y: auto; padding: 10px; background-color: #fff; margin-top: 10px;} #task-header { background-color: #0288d1; color: white; font-size: 18px; padding: 8px; text-align: center; margin-bottom: 5px; margin-top: 40px; } #category-header { background-color: #ecb939; font-size: 16px; padding: 8px; margin: 10px 0; } """ # State to store the selected task selected_task = gr.State(value="") # Gradio Interface with gr.Blocks(theme="soft", css=css) as demo: gr.HTML('
Logo
') gr.Markdown("# LitBench Interface") # Setup row for user preferences with gr.Row(visible=True) as setup_row: with gr.Column(): gr.Markdown("### Setup Your Preferences") download_option = gr.Dropdown( choices=["Download Paper", "Don't Download"], value="Download Paper", label="Download Option" ) train_option = gr.Dropdown( choices=["Train", "Don't Train"], value="Train", label="Training Option" ) setup_button = gr.Button("Set Preferences and Proceed") # Chatbot row for user interaction with gr.Row(visible=False) as chatbot_row: # Store the currently selected task with gr.Column(scale=3): gr.Markdown("### Start Chatting!") chatbot = gr.ChatInterface( predict, chatbot=gr.Chatbot( height=400, type="messages", avatar_images=[ "https://icons.veryicon.com/png/o/miscellaneous/user-avatar/user-avatar-male-5.png", "https://cdn-icons-png.flaticon.com/512/8649/8649595.png" ], ), textbox=gr.Textbox(placeholder="Type your message here..."), additional_inputs=selected_task, additional_inputs_accordion=gr.Accordion(visible=False, label="Additional Inputs", ), ) # Store user preferences and selected task for display preferences_output = gr.Textbox(value="", interactive=False, label="Your Preferences") # Task selection buttons for user interaction with gr.Column(scale=1): gr.HTML('
Tasks:
') with gr.Column(elem_id="scroll-menu"): # Create buttons button_map = {prompt: gr.Button(prompt) for prompt in task_list} for prompt in task_list: button_map[prompt].click( toggle_selection, inputs=[selected_task, gr.State(value=prompt)], # Toggle task selection outputs=selected_task ).then( update_button_styles, # Update button appearances inputs=[selected_task], outputs=[button_map[p] for p in task_list] # Update all buttons ) # Setup button to finalize user preferences and start chatbot setup_button.click( setup, inputs=[download_option, train_option], outputs=[setup_row, chatbot_row, preferences_output, chatbot.chatbot] ) # Launch the interface demo.launch(share=False, debug=False)