aquibmoin committed on
Commit
8b8df4a
·
verified ·
1 Parent(s): 59f0172

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -204
app.py CHANGED
@@ -1,6 +1,3 @@
1
- # FC-RAG with FAISS
2
-
3
- # re-build: 02/03/2025
4
 
5
  import gradio as gr
6
  from transformers import AutoTokenizer, AutoModel
@@ -18,8 +15,7 @@ import tempfile
18
  from astroquery.nasa_ads import ADS
19
  import pyvo as vo
20
  import pandas as pd
21
- import faiss
22
- from PyPDF2 import PdfReader
23
 
24
  # Load the NASA-specific bi-encoder model and tokenizer
25
  bi_encoder_model_name = "nasa-impact/nasa-smd-ibm-st-v2"
@@ -33,6 +29,12 @@ client = OpenAI(api_key=api_key)
33
  # Set up NASA ADS token
34
  ADS.TOKEN = os.getenv('ADS_API_KEY') # Ensure your ADS API key is stored in environment variables
35
 
 
 
 
 
 
 
36
  # Define system message with instructions
37
  system_message = """
38
  You are ExosAI, an advanced assistant specializing in Exoplanet and Astrophysics research.
@@ -73,85 +75,32 @@ Generate a **detailed and structured** response based on the given **science con
73
  Ensure the response is **structured, clear, and observation requirements table follows this format**. **All included parameters must be scientifically consistent with each other.**
74
  """
75
 
76
- def encode_text(text):
 
77
  inputs = bi_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
78
  outputs = bi_model(**inputs)
79
- return outputs.last_hidden_state.mean(dim=1).detach().numpy().flatten()
80
-
81
- def get_chunks(text, chunk_size=500):
82
- """
83
- Splits a long text into smaller chunks of approximately 'chunk_size' characters.
84
- Ensures that chunks do not cut off words abruptly.
85
- """
86
- if not text.strip():
87
- raise ValueError("The provided text is empty or blank.")
88
-
89
- # Split the text into chunks of approximately 'chunk_size' characters
90
- chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
91
-
92
- return chunks
93
-
94
- def load_and_process_uploaded_pdfs(pdf_files):
95
- """Extracts text from PDFs, splits into chunks, generates embeddings, and stores in FAISS."""
96
-
97
- # **RESET FAISS INDEX on every function call**
98
- embedding_dim = 768 # NASA Bi-Encoder embedding size
99
- index = faiss.IndexFlatIP(embedding_dim) # Fresh FAISS index
100
-
101
- pdf_chunks = [] # Store extracted chunks
102
- chunk_embeddings = [] # Store embeddings
103
-
104
- for pdf_file in pdf_files:
105
- reader = PdfReader(pdf_file)
106
- pdf_text = ""
107
- for page in reader.pages:
108
- pdf_text += page.extract_text() + "\n"
109
-
110
- # **Reduce Chunk Size for Faster Processing**
111
- chunks = get_chunks(pdf_text, chunk_size=300)
112
- pdf_chunks.extend(chunks) # Store for retrieval
113
-
114
- # Generate embeddings for each chunk
115
- for chunk in chunks:
116
- chunk_embedding = encode_text(chunk).reshape(1, -1)
117
-
118
- # Normalize for cosine similarity
119
- chunk_embedding = chunk_embedding / np.linalg.norm(chunk_embedding)
120
-
121
- index.add(chunk_embedding) # **Now adding to fresh FAISS index**
122
- chunk_embeddings.append(chunk_embedding)
123
 
124
- return index, pdf_chunks, chunk_embeddings # Return fresh FAISS index and chunk data
125
 
126
-
127
-
128
- def retrieve_relevant_context(user_input, context_text, science_objectives="", index=None, pdf_chunks=None, k=3):
129
- """
130
- Retrieve the most relevant document chunks using FAISS similarity search.
131
- Uses combined user inputs (Science Goal + Context + Optional Science Objectives).
132
- """
133
- if index is None or pdf_chunks is None:
134
- return "No indexed data available for retrieval."
135
-
136
- # Combine all user inputs into a single query
137
  query_text = f"Science Goal: {user_input}\nContext: {context_text}\nScience Objectives: {science_objectives}" if science_objectives else f"Science Goal: {user_input}\nContext: {context_text}"
 
138
 
139
- # Generate query embedding
140
- query_embedding = encode_text(query_text).reshape(1, -1)
141
-
142
- # Normalize the query embedding for cosine similarity
143
- query_embedding = query_embedding / np.linalg.norm(query_embedding)
144
-
145
- # Perform FAISS search to get top-k relevant chunks
146
- _, top_indices = index.search(query_embedding, k)
147
 
148
- # Retrieve the most relevant chunks using top indices
149
- retrieved_context = "\n\n".join([pdf_chunks[i] for i in top_indices[0]]) # FAISS returns indices in a nested list
150
 
151
- # If no relevant chunk is found, return a default message
152
  if not retrieved_context.strip():
153
  return "No relevant context found for the query."
154
-
155
  return retrieved_context
156
 
157
  def extract_keywords_with_gpt(user_input, max_tokens=100, temperature=0.3):
@@ -412,24 +361,19 @@ def gpt_response_to_dataframe(gpt_response):
412
  df = pd.DataFrame(rows, columns=headers)
413
  return df
414
 
415
- def chatbot(user_input, science_objectives="", context="", subdomain="", uploaded_pdfs=None, max_tokens=150, temperature=0.7, top_p=0.9, frequency_penalty=0.5, presence_penalty=0.0):
416
- # Load and process uploaded PDFs (if provided)
417
- if uploaded_pdfs:
418
- index, pdf_chunks, chunk_embeddings = load_and_process_uploaded_pdfs(uploaded_pdfs)
419
- else:
420
- pdf_chunks, chunk_embeddings = [], [] # Ensure empty list if no PDFs provided
421
 
422
- # Retrieve relevant context using document search
423
- relevant_context = retrieve_relevant_context(user_input, context, science_objectives, index, pdf_chunks)
424
 
425
  # Fetch NASA ADS references using the full prompt
426
  references = fetch_nasa_ads_references(subdomain)
427
 
428
- # Generate response from GPT-4, ensuring we pass all relevant inputs
429
  response = generate_response(
430
  user_input=user_input,
431
  science_objectives=science_objectives,
432
- relevant_context=relevant_context, # Ensure retrieved FAISS context is passed
433
  references=references,
434
  max_tokens=max_tokens,
435
  temperature=temperature,
@@ -438,161 +382,64 @@ def chatbot(user_input, science_objectives="", context="", subdomain="", uploade
438
  presence_penalty=presence_penalty
439
  )
440
 
441
- # Append manually entered science objectives to the response (if provided)
442
  if science_objectives.strip():
443
  response = f"### Science Objectives (User-Defined):\n\n{science_objectives}\n\n" + response
444
-
445
- # Export the response to a Word document
446
  word_doc_path = export_to_word(
447
- response,
448
- subdomain,
449
- user_input,
450
- context,
451
- max_tokens,
452
- temperature,
453
- top_p,
454
- frequency_penalty,
455
- presence_penalty
456
  )
457
 
458
- # Fetch exoplanet data
459
  exoplanet_data = fetch_exoplanet_data()
460
-
461
- # Generate insights based on the user query and exoplanet data
462
  data_insights = generate_data_insights(user_input, exoplanet_data)
463
 
464
- # Extract and convert the table from the GPT-4 response into a DataFrame
465
  extracted_table_df = gpt_response_to_dataframe(response)
466
 
467
- # Combine the response and the data insights
468
  full_response = f"{response}\n\nEnd of Response"
469
-
470
- # Embed Miro iframe
471
- iframe_html = """
472
- <iframe width="768" height="432" src="https://miro.com/app/live-embed/uXjVKuVTcF8=/?moveToViewport=-331,-462,5434,3063&embedId=710273023721" frameborder="0" scrolling="no" allow="fullscreen; clipboard-read; clipboard-write" allowfullscreen></iframe>
473
- """
474
-
475
- mapify_button_html = """
476
- <style>
477
- .mapify-button {
478
- background: linear-gradient(135deg, #1E90FF 0%, #87CEFA 100%);
479
- border: none;
480
- color: white;
481
- padding: 15px 35px;
482
- text-align: center;
483
- text-decoration: none;
484
- display: inline-block;
485
- font-size: 18px;
486
- font-weight: bold;
487
- margin: 20px 2px;
488
- cursor: pointer;
489
- border-radius: 25px;
490
- transition: all 0.3s ease;
491
- box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
492
- }
493
- .mapify-button:hover {
494
- background: linear-gradient(135deg, #4682B4 0%, #1E90FF 100%);
495
- box-shadow: 0 6px 20px rgba(0, 0, 0, 0.3);
496
- transform: scale(1.05);
497
- }
498
- </style>
499
- <a href="https://mapify.so/app/new" target="_blank">
500
- <button class="mapify-button">Create Mind Map on Mapify</button>
501
- </a>
502
- """
503
  return full_response, extracted_table_df, word_doc_path, iframe_html, mapify_button_html
504
 
505
  with gr.Blocks() as demo:
506
- gr.Markdown("# **ExosAI - NASA SMD FC-RAG SCDD Generator [version-1.1]**")
507
 
508
- # User Inputs
509
  gr.Markdown("## **User Inputs**")
510
  user_input = gr.Textbox(lines=5, placeholder="Enter your Science Goal...", label="Science Goal")
511
  context = gr.Textbox(lines=10, placeholder="Enter Context Text...", label="Additional Context")
512
  subdomain = gr.Textbox(lines=2, placeholder="Define your Subdomain...", label="Subdomain Definition")
513
 
514
- # PDF Upload Section (Up to 3 PDFs)
515
- gr.Markdown("### **Documents for Context Retrieval [e.g. LUVOIR, HabEx Reports]**")
516
- uploaded_pdfs = gr.Files(file_types=[".pdf"], label="Upload Reference PDFs (Up to 3)", interactive=True)
517
-
518
- # Science Objectives Button & Input (Initially Hidden)
519
  science_objectives_button = gr.Button("User-defined Science Objectives [Optional]")
520
- science_objectives_input = gr.Textbox(
521
- lines=5,
522
- placeholder="Enter Science Objectives...",
523
- label="Science Objectives",
524
- visible=False # Initially hidden
525
- )
526
 
527
- # Event to Show Science Objectives Input
528
- science_objectives_button.click(
529
- fn=lambda: gr.update(visible=True), # Show textbox when clicked
530
- inputs=[],
531
- outputs=[science_objectives_input]
532
- )
533
-
534
- # Additional Model Parameters
535
  gr.Markdown("### **Model Parameters**")
536
- max_tokens = gr.Slider(50, 2000, value=150, step=10, label="Max Tokens")
537
- temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="Temperature")
538
- top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.1, label="Top-p")
539
- frequency_penalty = gr.Slider(0.0, 1.0, value=0.5, step=0.1, label="Frequency Penalty")
540
- presence_penalty = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Presence Penalty")
541
 
542
- # Outputs
543
  gr.Markdown("## **Model Outputs**")
544
  full_response = gr.Textbox(label="ExosAI finds...")
545
  extracted_table_df = gr.Dataframe(label="SC Requirements Table")
546
- word_doc_path = gr.File(label="Download SCDD", type="filepath")
547
  iframe_html = gr.HTML(label="Miro")
548
  mapify_button_html = gr.HTML(label="Generate Mind Map on Mapify")
549
 
550
- # Buttons: Generate + Reset
551
  with gr.Row():
552
  submit_button = gr.Button("Generate SCDD")
553
  clear_button = gr.Button("Reset")
554
 
555
- # Define interaction: When "Generate SCDD" is clicked
556
- submit_button.click(
557
- fn=chatbot,
558
- inputs=[
559
- user_input, science_objectives_input, context, subdomain, uploaded_pdfs,
560
- max_tokens, temperature, top_p, frequency_penalty, presence_penalty
561
- ],
562
- outputs=[full_response, extracted_table_df, word_doc_path, iframe_html, mapify_button_html]
563
- )
564
-
565
- # Define Clear Function (Ensuring the correct number of outputs)
566
- def clear_all():
567
- return (
568
- "", # user_input
569
- "", # science_objectives_input
570
- "", # context
571
- "", # subdomain
572
- None, # uploaded_pdfs
573
- 150, # max_tokens
574
- 0.7, # temperature
575
- 0.9, # top_p
576
- 0.5, # frequency_penalty
577
- 0.0, # presence_penalty
578
- "", # full_response (textbox output)
579
- None, # extracted_table_df (DataFrame output)
580
- None, # word_doc_path (File output)
581
- None, # iframe_html (HTML output)
582
- None # mapify_button_html (HTML output)
583
- )
584
 
585
- # Bind Clear Button (Ensuring the correct number of outputs)
586
- clear_button.click(
587
- fn=clear_all,
588
- inputs=[],
589
- outputs=[
590
- user_input, science_objectives_input, context, subdomain, uploaded_pdfs,
591
- max_tokens, temperature, top_p, frequency_penalty, presence_penalty,
592
- full_response, extracted_table_df, word_doc_path, iframe_html, mapify_button_html
593
- ]
594
- )
595
 
596
- # Launch the app
597
  demo.launch(share=True)
598
 
 
 
 
 
1
 
2
  import gradio as gr
3
  from transformers import AutoTokenizer, AutoModel
 
15
  from astroquery.nasa_ads import ADS
16
  import pyvo as vo
17
  import pandas as pd
18
+ from pinecone import Pinecone
 
19
 
20
  # Load the NASA-specific bi-encoder model and tokenizer
21
  bi_encoder_model_name = "nasa-impact/nasa-smd-ibm-st-v2"
 
29
  # Set up NASA ADS token
30
  ADS.TOKEN = os.getenv('ADS_API_KEY') # Ensure your ADS API key is stored in environment variables
31
 
32
# Pinecone setup — connect to the pre-built SCDD vector index.
pinecone_api_key = os.getenv('PINECONE_API_KEY')
if not pinecone_api_key:
    # Fail fast with a clear message; Pinecone(api_key=None) would otherwise
    # surface an opaque client/auth error only on the first query.
    raise RuntimeError("PINECONE_API_KEY environment variable is not set")
pc = Pinecone(api_key=pinecone_api_key)
index_name = "scdd-index"
index = pc.Index(index_name)  # used by retrieve_relevant_context() below
37
+
38
  # Define system message with instructions
39
  system_message = """
40
  You are ExosAI, an advanced assistant specializing in Exoplanet and Astrophysics research.
 
75
  Ensure the response is **structured, clear, and observation requirements table follows this format**. **All included parameters must be scientifically consistent with each other.**
76
  """
77
 
78
# Function to encode query text
def encode_query(text):
    """Embed *text* with the NASA bi-encoder and return a unit-normalised vector.

    The sentence embedding is the mean of the model's final hidden states
    (truncated to 128 tokens). The result is returned as a plain Python list,
    the format Pinecone's query API expects for the ``vector`` argument.
    """
    inputs = bi_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
    outputs = bi_model(**inputs)
    embedding = outputs.last_hidden_state.mean(dim=1).detach().numpy().flatten()
    # Normalise so dot-product scoring behaves like cosine similarity.
    # Guard against a zero vector: 0/0 would fill the embedding with NaNs
    # and silently poison every subsequent similarity search.
    norm = np.linalg.norm(embedding)
    if norm > 0:
        embedding = embedding / norm
    return embedding.tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
 
86
 
87
# Context retrieval function using Pinecone
def retrieve_relevant_context(user_input, context_text, science_objectives="", top_k=3):
    """Return the ``top_k`` most relevant stored document chunks for the query.

    Combines the science goal, additional context and (optional) science
    objectives into one query string, embeds it with the bi-encoder, and runs
    a similarity search against the module-level Pinecone ``index``. Returns
    the matched chunk texts joined by blank lines, or a default message when
    nothing relevant is found.
    """
    # Combine all user inputs into a single query string.
    query_text = f"Science Goal: {user_input}\nContext: {context_text}\nScience Objectives: {science_objectives}" if science_objectives else f"Science Goal: {user_input}\nContext: {context_text}"
    query_embedding = encode_query(query_text)

    # Pinecone similarity search; metadata carries the original chunk text.
    query_response = index.query(
        vector=query_embedding,
        top_k=top_k,
        include_metadata=True
    )

    # Tolerate matches whose metadata lacks a 'text' field instead of raising
    # KeyError mid-join (assumes metadata is a dict when present — the shape
    # produced by include_metadata=True).
    retrieved_context = "\n\n".join(
        match['metadata'].get('text', '')
        for match in query_response.matches
        if match['metadata']
    )

    if not retrieved_context.strip():
        return "No relevant context found for the query."

    return retrieved_context
105
 
106
  def extract_keywords_with_gpt(user_input, max_tokens=100, temperature=0.3):
 
361
  df = pd.DataFrame(rows, columns=headers)
362
  return df
363
 
364
def chatbot(user_input, science_objectives="", context="", subdomain="", max_tokens=150, temperature=0.7, top_p=0.9, frequency_penalty=0.5, presence_penalty=0.0):
    """End-to-end pipeline behind the "Generate SCDD" button.

    Retrieves Pinecone context, fetches NASA ADS references, asks GPT for the
    structured science-case response, and packages everything the Gradio UI
    displays: response text, requirements table, Word document path, and the
    Miro/Mapify embed widgets.
    """
    # Retrieve relevant context using Pinecone.
    relevant_context = retrieve_relevant_context(user_input, context, science_objectives)

    # Fetch NASA ADS references from the subdomain definition.
    references = fetch_nasa_ads_references(subdomain)

    # Generate the GPT-4 response with all retrieved inputs.
    response = generate_response(
        user_input=user_input,
        science_objectives=science_objectives,
        relevant_context=relevant_context,
        references=references,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty
    )

    # Prepend user-defined science objectives when supplied.
    if science_objectives.strip():
        response = f"### Science Objectives (User-Defined):\n\n{science_objectives}\n\n" + response

    # Export the response to a downloadable Word document.
    word_doc_path = export_to_word(
        response, subdomain, user_input, context,
        max_tokens, temperature, top_p, frequency_penalty, presence_penalty
    )

    # Fetch exoplanet data and generate insights.
    # NOTE(review): data_insights is computed but never returned or appended
    # to the response — looks like an oversight; confirm before removing.
    exoplanet_data = fetch_exoplanet_data()
    data_insights = generate_data_insights(user_input, exoplanet_data)

    # Extract the GPT-generated requirements table into a DataFrame.
    extracted_table_df = gpt_response_to_dataframe(response)

    # Final text shown in the UI.
    full_response = f"{response}\n\nEnd of Response"

    # Static embed widgets returned to the UI.
    iframe_html = """<iframe width="768" height="432" src="https://miro.com/app/live-embed/uXjVKuVTcF8=/?moveToViewport=-331,-462,5434,3063&embedId=710273023721" frameborder="0" scrolling="no" allow="fullscreen; clipboard-read; clipboard-write" allowfullscreen></iframe>"""
    mapify_button_html = """<a href="https://mapify.so/app/new" target="_blank"><button>Create Mind Map on Mapify</button></a>"""

    return full_response, extracted_table_df, word_doc_path, iframe_html, mapify_button_html
409
 
410
with gr.Blocks() as demo:
    gr.Markdown("# **ExosAI - NASA SMD FCRAG SCDD Generator [version-2.1]**")

    # --- User inputs -------------------------------------------------------
    gr.Markdown("## **User Inputs**")
    user_input = gr.Textbox(lines=5, placeholder="Enter your Science Goal...", label="Science Goal")
    context = gr.Textbox(lines=10, placeholder="Enter Context Text...", label="Additional Context")
    subdomain = gr.Textbox(lines=2, placeholder="Define your Subdomain...", label="Subdomain Definition")

    # Optional science objectives: hidden textbox revealed by the button.
    science_objectives_button = gr.Button("User-defined Science Objectives [Optional]")
    science_objectives_input = gr.Textbox(lines=5, placeholder="Enter Science Objectives...", label="Science Objectives", visible=False)
    science_objectives_button.click(lambda: gr.update(visible=True), outputs=[science_objectives_input])

    # --- Generation parameters --------------------------------------------
    gr.Markdown("### **Model Parameters**")
    max_tokens = gr.Slider(50, 2000, 150, step=10, label="Max Tokens")
    temperature = gr.Slider(0.0, 1.0, 0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(0.0, 1.0, 0.9, step=0.1, label="Top-p")
    frequency_penalty = gr.Slider(0.0, 1.0, 0.5, step=0.1, label="Frequency Penalty")
    presence_penalty = gr.Slider(0.0, 1.0, 0.0, step=0.1, label="Presence Penalty")

    # --- Outputs -----------------------------------------------------------
    gr.Markdown("## **Model Outputs**")
    full_response = gr.Textbox(label="ExosAI finds...")
    extracted_table_df = gr.Dataframe(label="SC Requirements Table")
    word_doc_path = gr.File(label="Download SCDD")
    iframe_html = gr.HTML(label="Miro")
    mapify_button_html = gr.HTML(label="Generate Mind Map on Mapify")

    with gr.Row():
        submit_button = gr.Button("Generate SCDD")
        clear_button = gr.Button("Reset")

    # Named component lists keep the two event wirings in sync.
    chat_inputs = [
        user_input, science_objectives_input, context, subdomain,
        max_tokens, temperature, top_p, frequency_penalty, presence_penalty,
    ]
    chat_outputs = [full_response, extracted_table_df, word_doc_path, iframe_html, mapify_button_html]

    submit_button.click(chatbot, inputs=chat_inputs, outputs=chat_outputs)

    def _reset_all():
        """Restore every input widget to its default and blank the outputs."""
        return "", "", "", "", 150, 0.7, 0.9, 0.5, 0.0, "", None, None, None, None

    # Reset targets are exactly the inputs followed by the outputs, matching
    # the 14 values returned by _reset_all in order.
    clear_button.click(_reset_all, outputs=chat_inputs + chat_outputs)

demo.launch(share=True)
445