Andreas Varvarigos committed on
Commit
72222ce
·
verified ·
1 Parent(s): 91e5195

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -746
app.py DELETED
@@ -1,746 +0,0 @@
1
- from train import *
2
- from utils.utils import *
3
- from utils.graph_utils import *
4
- from utils.gradio_utils import *
5
- from retriever.retriever import retriever
6
- from tasks.abs_2_title import abs_2_title
7
- from tasks.abs_completion import abs_completion
8
- from tasks.citation_sentence import citation_sentence
9
- from tasks.intro_2_abs import intro_2_abs
10
- from tasks.link_pred import link_pred
11
- from tasks.paper_retrieval import paper_retrieval
12
- from tasks.influential_papers import influential_papers
13
- from tasks.gen_related_work import gen_related_work
14
- import random
15
- import json
16
- import os
17
- import re
18
- import networkx as nx
19
- import tarfile
20
- import gzip
21
- import time
22
- import urllib.request
23
- from tqdm import tqdm
24
- from colorama import Fore
25
- import wandb
26
- import gradio as gr
27
- from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList, TextIteratorStreamer, pipeline
28
- from threading import Thread
29
- import signal
30
- import gzip
31
- import time
32
- import torch
33
- from peft.peft_model import PeftModel
34
- from datasets import load_dataset
35
-
36
-
37
-
38
# Function to determine the chatbot's first message based on user choices
def setup(download_option, train_option):
    """Persist the user's download/train preferences and pick the chatbot's opening line.

    Returns updates that hide the setup row, reveal the chat row, echo the
    chosen preferences as text, and seed the chat history with one assistant message.
    """
    will_download = download_option == "Download Paper"
    will_train = train_option == "Train"
    download_papers.value = will_download
    train_model.value = will_train

    # Opening message depends on which pipeline stages will run first.
    if will_download:
        greeting = "Hello, what domain are you interested in?"
    elif will_train:
        greeting = "What domain is your graph about?"
    else:
        greeting = "Please provide your task prompt."
    initial_message = [{"role": "assistant", "content": greeting}]

    return (
        gr.update(visible=False),
        gr.update(visible=True),
        f"Download: {download_option}\nTrain: {train_option}",
        initial_message,
    )
51
-
52
-
53
# Function to toggle the selected task based on user input
def update_button_styles(selected_task):
    """Return one style update per task button: primary for the selected task, secondary otherwise."""
    updates = []
    for prompt in task_list:
        variant = "primary" if selected_task == prompt else "secondary"
        updates.append(gr.update(variant=variant))
    return updates
57
-
58
-
59
# Fetch and store arXiv source files
def fetch_arxiv_papers(papers_to_download):
    """Download arXiv source tarballs for the given paper ids, extract and clean
    the LaTeX, mine the Introduction / Related Work sections and their citation
    sentences, and incrementally save a citation graph plus summary statistics.

    Args:
        papers_to_download: list of arXiv paper ids to process.

    Returns:
        A status string summarizing how many LaTeX papers were processed.

    BUGFIX: the related-work counter previously incremented
    ``results["Number of intro/related found"]`` — a key that is never
    initialized in ``results`` — raising KeyError inside the outer try/except
    and aborting the rest of the paper's processing. It now increments the
    initialized key ``"Number of related works found"``.
    """
    # Download the arXiv metadata file if it doesn't exist
    dataset = 'datasets/arxiv-metadata-oai-snapshot.json'
    data = []
    if not os.path.exists(dataset):
        os.system("wget https://huggingface.co/spaces/ddiddu/simsearch/resolve/main/arxiv-metadata-oai-snapshot.json -P ./datasets")

    with open(dataset, 'r') as f:
        for line in f:
            data.append(json.loads(line))

    papers = [d for d in data]
    paper_ids = [d['id'] for d in data]
    # Normalize titles (collapse whitespace, strip LaTeX emphasis macros,
    # braces and punctuation, lowercase) so bibliography titles can be matched
    # against metadata titles later.
    paper_titles = [
        (
            re.sub(r' +', ' ', re.sub(r'[\n]+', ' ', paper['title']))
            .replace("\\emph", "")
            .replace("\\emp", "")
            .replace("\\em", "")
            .replace(",", "")
            .replace("{", "")
            .replace("}", "")
            .strip(".")
            .strip()
            .strip(".")
            .lower()
        )
        for paper in papers
    ]
    # normalized title -> arXiv id
    paper_dict = {
        k: v
        for k, v in zip(paper_titles, paper_ids)
    }

    total_papers = len(papers_to_download)
    download_progress_bar = gr.Progress()

    llm_resp = []
    # Keys here are runtime identifiers persisted to results.json — do not rename.
    results = {
        "Number of papers": 0,
        "Number of latex papers": 0,
        "Number of bib files": 0,
        "Number of bbl files": 0,
        "Number of inline files": 0,
        "Number of introductions found": 0,
        "Number of related works found": 0,
        "Number of succesful finding of extracts": 0
    }
    num_papers, num_edges, t, iter_ind = 0, 0, 0, 0
    graph = {}

    arxiv_rate_lim = config['data_downloading']['processing']['arxiv_rate_limit']
    for paper_name in tqdm(papers_to_download):
        results["Number of papers"] += 1
        print(
            Fore.BLUE + "Number of papers processed: {} \n Number of edges found: {} \n Time of previous iter: {} \n Now processing paper: {} \n\n"
            .format(num_papers, num_edges, time.time()-t, paper_name) + Fore.RESET
        )
        t = time.time()
        num_papers += 1

        # Prepare the paper name for downloading and saving
        # (old-style ids like "hep-th/9901001" become "hep-th9901001" on disk)
        paper_name_download = paper_name
        if re.search(r'[a-zA-Z]', paper_name) is not None:
            paper_name = "".join(paper_name.split('/'))
        tar_file_path = save_zip_directory + paper_name + '.tar.gz'

        # Attempt to download the paper source files from arXiv
        try:
            # Track start time for download (used for rate limiting below)
            t1 = time.time()
            urllib.request.urlretrieve(
                "https://arxiv.org/src/"+paper_name_download,
                tar_file_path)
        except Exception as e:
            print("Couldn't download paper {}".format(paper_name))
            # Skip to the next paper if download fails
            continue

        # Define the directory where the paper will be extracted
        extracted_dir = save_directory + paper_name + '/'
        if not os.path.exists(extracted_dir):
            os.makedirs(extracted_dir)

        # Attempt to extract the tar.gz archive
        try:
            tar = tarfile.open(tar_file_path)
            tar.extractall(extracted_dir)
            tar.close()
        except Exception as e:
            # Some arXiv "src" downloads are a single gzipped .tex file,
            # not a tarball — fall back to gzip extraction.
            try:
                with gzip.open(tar_file_path, 'rb') as f:
                    file_content = f.read()

                # Save the extracted content as a .tex file
                with open(extracted_dir+paper_name+'.tex', 'w') as f:
                    f.write(file_content.decode())
            except Exception as e:
                print("Could not extract paper id: {}".format(paper_name))
                # Skip this paper if extraction fails
                continue

        try:
            # Perform initial cleaning and get the main TeX file
            initial_clean(extracted_dir, config=False)
            main_file = get_main(extracted_dir)

            # If no main TeX file is found, remove the downloaded archive and continue
            if main_file is None:
                print("No tex files found")
                os.remove(tar_file_path)
                continue

            # Check if the main TeX file contains a valid LaTeX document
            h = check_begin(main_file)
            if h == True:
                results["Number of latex papers"] += 1
                # Flag to check for internal (inline) bibliography
                check_internal = 0
                # Dictionary to store bibliographic references
                final_library = {}

                # Identify bibliography files (.bib or .bbl)
                bib_files = find_bib(extracted_dir)
                if bib_files == []:
                    bbl_files = find_bbl(extracted_dir)
                    if bbl_files == []:
                        # No external bibliography found — look inline later
                        check_internal = 1
                    else:
                        final_library = get_library_bbl(bbl_files)
                        results["Number of bbl files"] += 1
                else:
                    results["Number of bib files"] += 1
                    final_library = get_library_bib(bib_files)

                # Apply post-processing to clean the TeX document
                main_file = post_processing(extracted_dir, main_file)

                # Read the cleaned LaTeX document content
                descr = main_file
                content = read_tex_file(descr)

                # If configured, store the raw content in the graph
                if config['data_downloading']['processing']['keep_unstructured_content']:
                    graph[paper_name] = {'content': content}
                else:
                    graph[paper_name] = {}

                # Check for inline bibliography within the LaTeX document
                if check_internal == 1:
                    beginning_bib = '\\begin{thebibliography}'
                    end_bib = '\\end{thebibliography}'

                    if content.find(beginning_bib) != -1 and content.find(end_bib) != -1:
                        bibliography = content[content.find(beginning_bib):content.find(end_bib) + len(end_bib)]
                        save_bbl = os.path.join(extracted_dir, "bibliography.bbl")

                        results["Number of inline files"] += 1
                        with open(save_bbl, "w") as f:
                            f.write(bibliography)

                        final_library = get_library_bbl([save_bbl])

                # If no valid bibliography is found, skip processing citations
                if final_library == {}:
                    print("No library found...")
                    continue

                # Extract relevant sections such as "Related Work" and "Introduction"
                related_works = get_related_works(content)
                if related_works != '':
                    graph[paper_name]['Related Work'] = related_works
                    # BUGFIX: was "Number of intro/related found" (uninitialized key -> KeyError)
                    results["Number of related works found"] += 1

                intro = get_intro(content)
                if intro != '':
                    graph[paper_name]['Introduction'] = intro
                    results["Number of introductions found"] += 1

                # Extract citation sentences from the introduction and related works
                sentences_citing = get_citing_sentences(intro + '\n' + related_works)

                # Map bibliography keys in each sentence to arXiv ids via title match;
                # unmatched citations are silently dropped.
                raw_sentences_citing = {}
                for k, v in sentences_citing.items():
                    new_values = []
                    for item in v:
                        try:
                            new_values.append(paper_dict[final_library[item]['title']])
                        except Exception as e:
                            pass
                    if new_values != []:
                        raw_sentences_citing[k] = new_values

                # Construct citation edges (source id, target id, citing sentence)
                edges_set = []
                for k, v in raw_sentences_citing.items():
                    for item in v:
                        edges_set.append((paper_name_download, item, {"sentence": k}))

                iter_ind += 1
                if len(edges_set) != 0:
                    results["Number of succesful finding of extracts"] += 1
                    graph[paper_name]['Citations'] = edges_set
                    num_edges += len(edges_set)

                # Save progress after every 10 iterations
                if iter_ind % 10 == 0:
                    print("Saving graph now")
                    with open(save_path, 'w') as f:
                        json.dump(results, f)
                    with open(save_graph, 'w') as f:
                        json.dump(graph, f)

        except Exception as e:
            print("Could not get main paper {}".format(paper_name))

        # Update the progress bar after processing each paper
        download_progress_bar(num_papers / total_papers)

        # Ensure a minimum time gap between iterations to avoid bans from arXiv
        t2 = time.time()  # End time
        elapsed_time = t2 - t1
        if elapsed_time < arxiv_rate_lim:
            time.sleep(arxiv_rate_lim - elapsed_time)

    # Final saving of processed data
    with open(save_graph, 'w') as f:
        json.dump(graph, f)
    with open(save_path, 'w') as f:
        json.dump(results, f)

    # Log final completion message
    llm_resp.append("✅ Successfully downloaded and cleaned {} papers.".format(results["Number of latex papers"]))
    return "\n".join(llm_resp)
302
-
303
-
304
# Chat prediction function
def predict(message, history, selected_task):
    """Gradio chat handler: on the first turn, optionally retrieve/download
    papers and train a QLoRA adapter; afterwards, run the selected task
    (or free chat) and stream the model's answer.

    Args:
        message: the user's current message.
        history: Gradio `type="messages"` history (dicts with 'role'/'content').
        selected_task: the task button the user toggled ('' for free chat).

    Yields:
        Progress/status strings and the (partial) model response.

    BUGFIXES:
    - ``Thread(target=update_progress())`` called the function inline (blocking
      the handler until training finished) and passed its ``None`` return as
      the thread target; now passes the callable itself.
    - Target-node concepts were looked up with ``target in concept_data``
      (a DatasetDict keyed by split names, so never true) — cited papers
      always got empty concepts; now uses the ``id2topics`` mapping that was
      built for this purpose, matching the source-node branch.
    """
    global model
    # Initialize the conversation string
    conversation = ""

    # Parse the history: Gradio `type="messages"` uses dictionaries with 'role' and 'content'
    for item in history:
        if item["role"] == "assistant":
            conversation += f"<bot>: {item['content']}\n"
        elif item["role"] == "user":
            conversation += f"<human>: {item['content']}\n"

    # Add the user's current message to the conversation
    conversation += f"<human>: {message}\n<bot>:"

    # Handle preferences: no download, no training -> load the pretrained adapter
    if len(history) == 0:
        if not download_papers.value and not train_model.value:
            yield "✅ Using model from configuration file..."

            adapter_path = config["inference"]["pretrained_model"]
            peft_model = PeftModel.from_pretrained(model, adapter_path, torch_dtype=torch.float16)

            # change the global model with peft model
            model = peft_model

            time.sleep(2.5)

    # Skip generation on the very first turn when that turn is used for
    # downloading/training (handled further below).
    if not (len(history) == 0 and (train_model.value or download_papers.value)):
        # Streamer for generating responses
        streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
        stop = StopOnTokens()

        generate_kwargs = {
            "streamer": streamer,
            "max_new_tokens": config['inference']['generation_args']["max_new_tokens"],
            "do_sample": config['inference']['generation_args']["do_sample"],
            "top_p": config['inference']['generation_args']["top_p"],
            "top_k": config['inference']['generation_args']["top_k"],
            "temperature": config['inference']['generation_args']["temperature"],
            "no_repeat_ngram_size": config['inference']['generation_args']["no_repeat_ngram_size"],
            "num_beams": config['inference']['generation_args']["num_beams"],
            "stopping_criteria": StoppingCriteriaList([stop]),
        }

        def generate_response(model, generate_kwargs, selected_task):
            # Build the task-specific prompt and either stream tokens (simple
            # tasks) or write the full answer into `advanced_tasks_out`
            # (graph-based tasks).
            global advanced_tasks_out
            has_predefined_template = generate_kwargs["streamer"].tokenizer.chat_template is not None

            if selected_task == "Abstract Completion":
                prompt = abs_completion(message, template, has_predefined_template)
            elif selected_task == "Title Generation":
                prompt = abs_2_title(message, template, has_predefined_template)
            elif selected_task == "Citation Recommendation":
                prompt = paper_retrieval(message, template, has_predefined_template)
            elif selected_task == "Citation Sentence Generation":
                prompt = citation_sentence(message, template, has_predefined_template)
            elif selected_task == "Citation Link Prediction":
                prompt = link_pred(message, template, has_predefined_template)
            elif selected_task == "Introduction to Abstract":
                prompt = intro_2_abs(message, template, tokenizer.model_max_length, has_predefined_template)
            elif selected_task == "Influential Papers Recommendation":
                if download_papers.value:
                    graph = nx.read_gexf(gexf_file)
                    advanced_tasks_out = influential_papers(message, graph)
                else:
                    graph = nx.read_gexf(predef_graph)
                    advanced_tasks_out = influential_papers(message, graph)
            elif selected_task == "Related Work Generation":
                adapter_path = (
                    f"{config['model_saving']['model_output_dir']}/{config['model_saving']['model_name']}_{config['model_saving']['index']}_adapter_test_graph"
                    if train_model.value else config['inference']['pretrained_model']
                )
                if download_papers.value:
                    advanced_tasks_out = gen_related_work(message, gexf_file, adapter_path)
                else:
                    advanced_tasks_out = gen_related_work(message, predef_graph, adapter_path)
            else:
                # Free chat: fall back to the raw conversation transcript
                prompt = conversation + f"<human>: {message}\n<bot>:"

            if selected_task != "Influential Papers Recommendation" and selected_task != "Related Work Generation":
                if tokenizer.chat_template is not None:
                    response = model_pipeline(prompt, **generate_kwargs)
                    streamer.put(response[0]['generated_text'][-1])
                else:
                    model_inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
                    generate_kwargs["inputs"] = model_inputs["input_ids"]
                    generate_kwargs["attention_mask"] = model_inputs["attention_mask"]

                    response = model.generate(**generate_kwargs)
                    streamer.put(response)

        # Generate the response in a separate thread
        t = Thread(target=generate_response,
                   kwargs={
                       "model": model,
                       "generate_kwargs": generate_kwargs,
                       "selected_task": selected_task
                   })

        global advanced_tasks_out
        advanced_tasks_out = None
        t.start()

        # Stream the partial response
        if selected_task != "Influential Papers Recommendation" and selected_task != "Related Work Generation":
            partial_message = ""
            for new_token in streamer:
                if new_token != '<':  # Ignore placeholder tokens
                    partial_message += new_token
                    yield partial_message
        else:
            if selected_task == "Related Work Generation":
                yield "🔍 Generating related work..."
            # Poll until the worker thread publishes its result
            while advanced_tasks_out is None:
                time.sleep(0.1)
            yield advanced_tasks_out

    # Fetch arXiv papers if the user opted to download them
    if len(history) == 0:
        if download_papers.value:
            # Fetch relevant papers
            yield "🔍 Retrieving relevant papers..."

            retrieve_progress = gr.Progress()
            for percent in retriever(message, retrieval_nodes_path):
                retrieve_progress(percent)

            with open(retrieval_nodes_path, "r") as f:
                data_download = json.load(f)

            papers_to_download = list(data_download.keys())

            yield f"📥 Fetching {len(papers_to_download)} arXiv papers' source files... Please wait."

            content = fetch_arxiv_papers(papers_to_download)
            yield content
            time.sleep(2.5)

    # Train the model with the retrieved graph
    if len(history) == 0:
        if train_model.value:
            training_progress = gr.Progress()

            training_progress(0.0)

            # If the user opted to download papers, use the retrieved graph, else use the predefined graph
            if download_papers.value:
                yield "🚀 Training the model with the retrieved graph..."

                with open(save_graph, "r") as f:
                    data_graph = json.load(f)

                # Restore "archive/id" form for old-style ids that were
                # flattened for the filesystem (e.g. "hep-th9901001").
                renamed_data = {
                    "/".join(re.match(r"([a-z-]+)([0-9]+)", key, re.I).groups()) if re.match(r"([a-z-]+)([0-9]+)", key, re.I) else key: value
                    for key, value in data_graph.items()
                }

                concept_data = load_dataset("AliMaatouk/arXiv_Topics", cache_dir="datasets/arxiv_topics")
                # paper id -> [level-1, level-2, level-3] topic lists
                id2topics = {
                    entry["paper_id"]: [entry["Level 1"], entry["Level 2"], entry["Level 3"]]
                    for entry in concept_data["train"]
                }

                dataset = 'datasets/arxiv-metadata-oai-snapshot.json'
                data = []
                if not os.path.exists(dataset):
                    os.system("wget https://huggingface.co/spaces/ddiddu/simsearch/resolve/main/arxiv-metadata-oai-snapshot.json -P ./datasets")
                with open(dataset, 'r') as f:
                    for line in f:
                        data.append(json.loads(line))
                papers = {d['id']: d for d in data}

                G = nx.DiGraph()
                for k in renamed_data:
                    if k not in G and k in papers:
                        if config['data_downloading']['processing']['keep_unstructured_content']:
                            G.add_node(
                                k,
                                title=papers[k]['title'],
                                abstract=papers[k]['abstract'],
                                introduction=renamed_data[k].get('Introduction', '') if renamed_data[k].get('Introduction', '') != '\n' else '',
                                related=renamed_data[k].get('Related Work', '') if renamed_data[k].get('Related Work', '') != '\n' else '',
                                concepts=", ".join(list(set(item for sublist in id2topics[k] for item in sublist))) if k in id2topics else '',
                                content=renamed_data[k].get('content', '') if k in renamed_data else ''
                            )
                        else:
                            G.add_node(
                                k,
                                title=papers[k]['title'],
                                abstract=papers[k]['abstract'],
                                introduction=renamed_data[k].get('Introduction', '') if renamed_data[k].get('Introduction', '') != '\n' else '',
                                related=renamed_data[k].get('Related Work', '') if renamed_data[k].get('Related Work', '') != '\n' else '',
                                concepts=", ".join(list(set(item for sublist in id2topics[k] for item in sublist))) if k in id2topics else ''
                            )
                    if 'Citations' in renamed_data[k]:
                        for citation in renamed_data[k]['Citations']:
                            source, target, metadata = citation
                            sentence = metadata.get('sentence', '')  # Extract sentence or default to empty string

                            if target not in G and target in papers:
                                # BUGFIX: concepts previously checked `target in concept_data`
                                # (a DatasetDict — always false); use id2topics like the branch above.
                                if config['data_downloading']['processing']['keep_unstructured_content']:
                                    G.add_node(
                                        target,
                                        title=papers[target]['title'],
                                        abstract=papers[target]['abstract'],
                                        introduction=renamed_data[target].get('Introduction', '') if target in renamed_data and renamed_data[target].get('Introduction', '') != '\n' else '',
                                        related=renamed_data[target].get('Related Work', '') if target in renamed_data and renamed_data[target].get('Related Work', '') != '\n' else '',
                                        concepts=", ".join(list(set(item for sublist in id2topics[target] for item in sublist))) if target in id2topics else '',
                                        content=renamed_data[target].get('content', '') if target in renamed_data else ''
                                    )
                                else:
                                    G.add_node(
                                        target,
                                        title=papers[target]['title'],
                                        abstract=papers[target]['abstract'],
                                        introduction=renamed_data[target].get('Introduction', '') if target in renamed_data and renamed_data[target].get('Introduction', '') != '\n' else '',
                                        related=renamed_data[target].get('Related Work', '') if target in renamed_data and renamed_data[target].get('Related Work', '') != '\n' else '',
                                        concepts=", ".join(list(set(item for sublist in id2topics[target] for item in sublist))) if target in id2topics else ''
                                    )

                            G.add_edge(source, target, sentence=sentence)

                # Drop nodes that ended up with no citation edges
                G.remove_nodes_from(list(nx.isolates(G)))

                nx.write_gexf(G, gexf_file)
                print(f"Processed graph written to {gexf_file}")
            else:
                yield f"✅ Using predefined graph: {predef_graph}"

            wandb.init(project='qlora_train')

            if download_papers.value:
                trainer = QloraTrainer_CS(config=config, use_predefined_graph=False)
            else:
                trainer = QloraTrainer_CS(config=config, use_predefined_graph=True)

            print("Load base model")
            trainer.load_base_model()

            print("Start training")

            def update_progress():
                # Wait for the trainer to be initialized
                while trainer.transformer_trainer is None:
                    time.sleep(0.5)

                time.sleep(1.5)
                # Update the progress bar until training is complete
                while trainer.transformer_trainer.state.global_step != trainer.transformer_trainer.state.max_steps:
                    progress_bar = (
                        trainer.transformer_trainer.state.global_step /
                        trainer.transformer_trainer.state.max_steps
                    )
                    training_progress(progress_bar)
                    time.sleep(0.5)
                training_progress(1.0)

            t1 = Thread(target=trainer.train)
            t1.start()
            # BUGFIX: was Thread(target=update_progress()) — ran inline and gave Thread a None target
            t2 = Thread(target=update_progress)
            t2.start()
            t1.join()
            t2.join()

            yield "🎉 Model training complete! Please provide your task prompt."

            adapter_path = f"{config['model_saving']['model_output_dir']}/{config['model_saving']['model_name']}_{config['model_saving']['index']}_adapter_test_graph"
            peft_model = PeftModel.from_pretrained(model, adapter_path, torch_dtype=torch.float16)

            # change the global model with peft model
            model = peft_model
579
-
580
-
581
-
582
if __name__ == "__main__":
    # Sanity log: whether the app runs inside a virtualenv (is_venv comes from a star import).
    print("This is running in a virtual environment: {}".format(is_venv()))

    # Load the YAML config and the Alpaca prompt template used by the task builders.
    config = read_yaml_file("configs/config.yaml")
    template_file_path = 'configs/alpaca.json'
    # NOTE(review): file handle from open() is never closed — consider a `with` block.
    template = json.load(open(template_file_path, "r"))

    # Resolve all working paths from the config; these are read as module
    # globals by fetch_arxiv_papers() and predict().
    seed_no = config['data_downloading']['processing']['random_seed']
    model_name = config['inference']['base_model']
    working_dir = config['data_downloading']['download_directory']
    save_zip_directory = working_dir + 'research_papers_zip/'   # downloaded .tar.gz archives
    save_directory = working_dir + 'research_papers/'           # extracted LaTeX sources
    save_description = working_dir + 'description/'             # results + graph JSON
    save_path = save_description + 'results.json'
    save_graph = save_description + 'test_graph.json'
    gexf_file = save_description + config['data_downloading']['gexf_file']
    predef_graph = 'datasets/' + config['training']['predefined_graph_path']
    retrieval_nodes_path = 'datasets/retrieval_nodes.json'

    # Create the working directories on first run.
    isExist = os.path.exists(save_zip_directory)
    if not isExist:
        os.makedirs(save_zip_directory)
    isExist = os.path.exists(save_directory)
    if not isExist:
        os.makedirs(save_directory)
    isExist = os.path.exists(save_description)
    if not isExist:
        os.makedirs(save_description)

    random.seed(seed_no)

    # Load model and tokenizer (8-bit quantized base model on CUDA).
    # NOTE(review): BitsAndBytesConfig is not among the visible explicit imports —
    # presumably provided by one of the star imports at the top; verify.
    # NOTE(review): quant_type "nf8" looks unusual (bitsandbytes documents "nf4"
    # for 4-bit); confirm these bnb_8bit_* kwargs take effect for 8-bit loading.
    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True,
        bnb_8bit_use_double_quant=True,
        bnb_8bit_quant_type="nf8",
        bnb_8bit_compute_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
    if model.device.type != 'cuda':
        model.to('cuda')

    # For chat-template models, generation goes through a pipeline instead of
    # raw model.generate (see generate_response in predict()).
    if tokenizer.chat_template is not None:
        model_pipeline = pipeline(
            "text-generation",
            model=model_name,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )

    # Graceful Ctrl-C handling (signal_handler comes from a star import).
    signal.signal(signal.SIGINT, signal_handler)

    # Global States for User Preferences (read/written by setup() and predict())
    download_papers = gr.State(value=True)  # Default: Download papers
    train_model = gr.State(value=True)      # Default: Train the model

    # Categorized Recommended Prompts
    # NOTE(review): this is a set literal, so button order is not guaranteed
    # to be stable across interpreter runs (it is consistent within one run).
    task_list = {
        "Abstract Completion",
        "Introduction to Abstract",
        "Title Generation",
        "Citation Recommendation",
        "Citation Sentence Generation",
        "Citation Link Prediction",
        "Influential Papers Recommendation",
        "Related Work Generation",
    }

    # CSS for Styling
    css = """
    body { background-color: #E0F7FA; margin: 0; padding: 0; }
    .gradio-container { background-color: #E0F7FA; border-radius: 10px; }
    #logo-container { display: flex; justify-content: center; align-items: center; margin: 0 auto; padding: 0; max-width: 120px; height: 120px; border-radius: 10px; overflow: hidden; }
    #scroll-menu { max-height: 310px; overflow-y: auto; padding: 10px; background-color: #fff; margin-top: 10px;}
    #task-header { background-color: #0288d1; color: white; font-size: 18px; padding: 8px; text-align: center; margin-bottom: 5px; margin-top: 40px; }
    #category-header { background-color: #ecb939; font-size: 16px; padding: 8px; margin: 10px 0; }
    """

    # State to store the selected task (passed to predict as an additional input)
    selected_task = gr.State(value="")

    # Gradio Interface
    with gr.Blocks(theme="soft", css=css) as demo:
        gr.HTML('<div id="logo-container"><img src="https://static.thenounproject.com/png/6480915-200.png" alt="Logo"></div>')
        gr.Markdown("# LitBench Interface")

        # Setup row for user preferences (shown first, hidden after setup())
        with gr.Row(visible=True) as setup_row:
            with gr.Column():
                gr.Markdown("### Setup Your Preferences")
                download_option = gr.Dropdown(
                    choices=["Download Paper", "Don't Download"],
                    value="Download Paper",
                    label="Download Option"
                )
                train_option = gr.Dropdown(
                    choices=["Train", "Don't Train"],
                    value="Train",
                    label="Training Option"
                )
                setup_button = gr.Button("Set Preferences and Proceed")

        # Chatbot row for user interaction (revealed by setup())
        with gr.Row(visible=False) as chatbot_row:
            # Store the currently selected task
            with gr.Column(scale=3):
                gr.Markdown("### Start Chatting!")
                chatbot = gr.ChatInterface(
                    predict,
                    chatbot=gr.Chatbot(
                        height=400,
                        type="messages",
                        avatar_images=[
                            "https://icons.veryicon.com/png/o/miscellaneous/user-avatar/user-avatar-male-5.png",
                            "https://cdn-icons-png.flaticon.com/512/8649/8649595.png"
                        ],
                    ),
                    textbox=gr.Textbox(placeholder="Type your message here..."),
                    additional_inputs=selected_task,
                    additional_inputs_accordion=gr.Accordion(visible=False, label="Additional Inputs", ),
                )

                # Store user preferences and selected task for display
                preferences_output = gr.Textbox(value="", interactive=False, label="Your Preferences")

            # Task selection buttons for user interaction
            with gr.Column(scale=1):
                gr.HTML('<div id="task-header">Tasks:</div>')
                with gr.Column(elem_id="scroll-menu"):
                    # Create buttons (one per task)
                    button_map = {prompt: gr.Button(prompt) for prompt in task_list}

                    # Wire each button: toggle the task, then restyle all buttons.
                    # toggle_selection comes from a star import.
                    for prompt in task_list:
                        button_map[prompt].click(
                            toggle_selection,
                            inputs=[selected_task, gr.State(value=prompt)],  # Toggle task selection
                            outputs=selected_task
                        ).then(
                            update_button_styles,  # Update button appearances
                            inputs=[selected_task],
                            outputs=[button_map[p] for p in task_list]  # Update all buttons
                        )

        # Setup button to finalize user preferences and start chatbot
        setup_button.click(
            setup,
            inputs=[download_option, train_option],
            outputs=[setup_row, chatbot_row, preferences_output, chatbot.chatbot]
        )

    # Launch the interface
    demo.launch(server_port=7880)