Spaces:
Running
Running
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModel | |
| from openai import OpenAI | |
| import os | |
| import numpy as np | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from docx import Document | |
| from docx.shared import Pt | |
| from docx.enum.text import WD_PARAGRAPH_ALIGNMENT | |
| from docx.oxml.ns import nsdecls | |
| from docx.oxml import parse_xml | |
| import io | |
| import tempfile | |
| from astroquery.nasa_ads import ADS | |
| import pyvo as vo | |
| import pandas as pd | |
# --- Shared models and service clients, created once at import time --------

# NASA SMD bi-encoder: encode_text() uses this tokenizer/model pair to embed
# the user query and the context chunks for retrieval.
bi_encoder_model_name = "nasa-impact/nasa-smd-ibm-st-v2"
bi_tokenizer = AutoTokenizer.from_pretrained(bi_encoder_model_name)
bi_model = AutoModel.from_pretrained(bi_encoder_model_name)

# OpenAI client used by every GPT call; the key comes from OPENAI_API_KEY.
api_key = os.environ.get('OPENAI_API_KEY')
client = OpenAI(api_key=api_key)

# NASA ADS token (from ADS_API_KEY) used by fetch_nasa_ads_references().
ADS.TOKEN = os.environ.get('ADS_API_KEY')
# System prompt sent with every generate_response() call. It fixes the
# section layout of the answer (objectives, parameters, observables,
# observations, requirements table) and the constraints the requirements
# table must respect (bands, instrument types, wavelength / R / contrast
# ranges, no existing facilities).
system_message = """
You are ExosAI, an advanced assistant specializing in Exoplanet and Astrophysics research.
Generate a **detailed and structured** response based on the given **science context and user input**, incorporating key **observables, physical parameters, and technical requirements**. Organize the response into the following sections:
1. **Science Objectives**: Define key scientific objectives related to the science context and user input.
2. **Physical Parameters**: Outline the relevant physical parameters (e.g., mass, temperature, composition).
3. **Observables**: Specify the key observables required to study the science context.
4. **Description of Desired Observations**: Detail the observational techniques, instruments, or approaches necessary to gather relevant data.
5. **Observations Requirements Table**: Generate a table relevant to the Science Objectives, Physical Parameters, Observables and Description of Desired Observations with the following columns and at least 7 rows:
- Wavelength Band: Should only be UV, Visible and Infrared.
- Instrument: Should only be Imager, Spectrograph, Polarimeter and Coronagraph.
- Necessary Values: The necessary values or parameters (wavelength range, spectral resolution where applicable, spatial resolution where applicable, contrast ratio where applicable).
- Desired Values: The desired values or parameters (wavelength range, spectral resolution where applicable, spatial resolution where applicable).
- Number of Objects Observed: Estimate the number of objects that need to be observed for a statistically meaningful result or for fulfilling the science objective.
- Justification: Detailed scientific explanation of why these observations are important for the science objectives.
- Comments: Additional notes or remarks regarding each observation.
#### **Table Format**
| Wavelength Band | Instrument | Necessary Values | Desired Values | Number of Objects Observed | Justification | Comments |
|----------------------|------------------------------------|------------------------------------|---------------------------------|---------------------------------|-------------------|----------|
#### **Guiding Constraints (Exclusions & Prioritization)**
- **Wavelength Band Restriction:** Only include **UV, Visible, and Infrared** bands.
- **Instrument Restriction:** Only include **Imager, Spectrograph, Polarimeter, and Coronagraph**.
- **Wavelength Limits:** Prioritize wavelengths between **100 nanometers (nm) and 3 micrometers (μm)**.
- **Allowed Instruments:** **Only include** observations from **direct imaging, spectroscopy, and polarimetry.** **Exclude** transit and radial velocity methods.
- **Exclusion of Existing Facilities:** **Do not reference** existing observatories such as JWST, Hubble, or ground-based telescopes. This work pertains to a **new mission**.
- **Spectral Resolution Constraint:** Limit spectral resolution (**R**) to the range **10,000 – 50,000**.
- **Contrast Ratio:** Limit contrast ratio to the range **10^4 - 10^6**.
- **Estimate the "Number of Objects Observed" based on the observational strategy, statistical requirements, and feasibility.**
- **Ensure that all parameters remain scientifically consistent.**
**Use this table format as a guideline, generate a detailed table dynamically based on the input.** Ensure that all values align with the provided constraints and instructions.
Ensure the response is **structured, clear, and observation requirements table follows this format**. **All included parameters must be scientifically consistent with each other.**
"""
def encode_text(text):
    """Embed *text* with the NASA SMD bi-encoder and return a 1-D NumPy vector.

    The input is tokenized (truncated to 128 tokens) and passed through
    ``bi_model``. The last hidden state is averaged over ``dim=1`` (the token
    axis of Hugging Face encoder outputs) and then flattened.
    """
    import torch  # torch is not among the module-level imports

    inputs = bi_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=128)
    # Inference only. Without no_grad, every call records an autograd graph
    # for the whole forward pass, spending memory and time for nothing.
    with torch.no_grad():
        outputs = bi_model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).numpy().flatten()
def get_chunks(text, chunk_size=300):
    """Cut *text* into consecutive pieces of at most *chunk_size* characters.

    The cut is purely positional, so a piece may end mid-word. The final
    piece holds whatever is left over.

    Raises:
        ValueError: if *text* is empty or contains only whitespace.
    """
    if text.strip() == "":
        raise ValueError("The provided context is empty or blank.")
    offsets = range(0, len(text), chunk_size)
    pieces = [text[start:start + chunk_size] for start in offsets]
    return pieces
def retrieve_relevant_context(user_input, context_texts, chunk_size=300, similarity_threshold=0.3, return_score=False):
    """Return the chunk of *context_texts* most similar to *user_input*.

    The context is split into fixed-size character chunks with ``get_chunks``.
    The query and every chunk are embedded with ``encode_text``, and the chunk
    with the highest cosine similarity is chosen.

    Args:
        user_input: Query text.
        context_texts: Free-form context to search (may be ``None`` or blank).
        chunk_size: Characters per chunk.
        similarity_threshold: Minimum best cosine similarity required to
            accept a chunk.
        return_score: When True, return a ``(text, score)`` tuple instead of
            just the text.

    Returns:
        By default, a string in every case:
        - the most relevant chunk;
        - the only chunk, if the context fits in one chunk;
        - a fallback message, if the context is blank or no chunk reaches
          *similarity_threshold*.
        With ``return_score=True``, a ``(text, score)`` tuple. The score is
        ``1.0`` for the single-chunk case and ``None`` for the fallbacks.
    """
    def _result(text, score):
        # One return shape per call: callers f-string the default result, so
        # it must never be a tuple unless explicitly requested.
        return (text, score) if return_score else text

    if not context_texts or not context_texts.strip():
        return _result("Error: Context is empty or improperly formatted.", None)

    context_chunks = get_chunks(context_texts, chunk_size)

    # A single chunk is returned as-is without embedding (and without the threshold check).
    if len(context_chunks) == 1:
        return _result(context_chunks[0], 1.0)

    user_embedding = encode_text(user_input).reshape(1, -1)
    chunk_embeddings = np.array([encode_text(chunk) for chunk in context_chunks])
    similarities = cosine_similarity(user_embedding, chunk_embeddings).flatten()

    best_idx = int(np.argmax(similarities))
    best_score = float(similarities[best_idx])
    if best_score < similarity_threshold:
        return _result("No relevant context found for the user input.", None)
    return _result(context_chunks[best_idx], best_score)
def extract_keywords_with_gpt(user_input, max_tokens=100, temperature=0.3):
    """Ask GPT-4 to list the key terms, concepts and parameters in *user_input*.

    Returns:
        The model's reply with surrounding whitespace stripped.
    """
    chat_messages = [
        {"role": "system", "content": "You are an expert in identifying key scientific terms and concepts."},
        {"role": "user", "content": f"Extract the most important keywords, scientific concepts, and parameters from the following user query:\n\n{user_input}"},
    ]
    completion = client.chat.completions.create(
        model="gpt-4",
        messages=chat_messages,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    reply = completion.choices[0].message.content
    return reply.strip()
def fetch_nasa_ads_references(prompt):
    """Search NASA ADS for *prompt* and return up to five reference tuples.

    Each tuple is ``(title, authors, bibcode)``. *authors* holds the first
    three names, followed by " et al." when there are more.

    This is best-effort and never raises:
    - a blank query or an empty result gives
      ``[("No results found", "N/A", "N/A")]``;
    - any exception from the ADS call is reported as
      ``[("Error fetching references", <message>, "N/A")]``.
    """
    no_results = [("No results found", "N/A", "N/A")]
    # The caller passes the optional subdomain field here; an empty query
    # would only produce an ADS error, so skip the network call.
    if prompt is None or not str(prompt).strip():
        return no_results
    try:
        papers = ADS.query_simple(prompt)
        if papers is None or len(papers) == 0:
            return no_results
        return [
            (
                paper['title'][0],
                ", ".join(paper['author'][:3]) + (" et al." if len(paper['author']) > 3 else ""),
                paper['bibcode'],
            )
            for paper in papers[:5]  # limit to 5 references
        ]
    except Exception as e:  # best-effort: surface the failure as a pseudo-reference
        return [("Error fetching references", str(e), "N/A")]
def fetch_exoplanet_data():
    """Fetch a small sample of planets from the NASA Exoplanet Archive.

    Runs a TAP query against the ``pscomppars`` table. The query selects a
    fixed list of planet, host-star and system columns and takes the first
    10 rows (``TOP 10``). There is no ORDER BY, so the service decides which
    rows are returned.

    Returns:
        pandas.DataFrame with one row per planet.
    """
    adql = """
SELECT TOP 10 pl_name, hostname, sy_snum, sy_pnum, discoverymethod, disc_year, disc_facility, pl_controv_flag, pl_orbper, pl_orbsmax, pl_rade, pl_bmasse, pl_orbeccen, pl_eqt, st_spectype, st_teff, st_rad, st_mass, ra, dec, sy_vmag
FROM pscomppars
"""
    archive = vo.dal.TAPService("https://exoplanetarchive.ipac.caltech.edu/TAP")
    result_table = archive.search(adql).to_table()
    return result_table.to_pandas()
def generate_response(user_input, science_objectives="", relevant_context="", references=None, max_tokens=150, temperature=0.7, top_p=0.9, frequency_penalty=0.5, presence_penalty=0.0):
    """Generate the structured SCDD answer with GPT-4o (system prompt: ``system_message``).

    The user message depends on which optional inputs are present:
    - context and objectives: ask only for the remaining sections;
    - context only: ask for the full response, including Science Objectives;
    - objectives only: ask only for the remaining sections;
    - neither: ask for the full response.

    Args:
        references: Optional iterable of ``(title, authors, bibcode)`` tuples.
            When non-empty, they are appended after an "ADS References:" line.

    Returns:
        The model's reply (stripped), plus the references trailer if any.
    """
    objectives = (science_objectives or "").strip()

    if relevant_context and objectives:
        combined_input = f"Scientific Context: {relevant_context}\nUser Input: {user_input}\nScience Objectives (User Provided): {science_objectives}\n\nPlease generate only the remaining sections as per the defined format."
    elif relevant_context:
        combined_input = f"Scientific Context: {relevant_context}\nUser Input: {user_input}\n\nPlease generate a full structured response, including Science Objectives."
    elif objectives:
        # Only user-provided science objectives (no context).
        combined_input = f"User Input: {user_input}\nScience Objectives (User Provided): {science_objectives}\n\nPlease generate only the remaining sections as per the defined format."
    else:
        combined_input = f"User Input: {user_input}\n\nPlease generate a full structured response, including Science Objectives."

    completion = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": system_message},
            {"role": "user", "content": combined_input}
        ],
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty
    )
    # message.content can be None (e.g. a refusal); treat that as empty text.
    response_content = (completion.choices[0].message.content or "").strip()

    if references:
        references_text = "\n\nADS References:\n" + "\n".join(
            f"- {title} by {authors} (Bibcode: {bibcode})" for title, authors, bibcode in references
        )
        return f"{response_content}\n{references_text}"
    return response_content
def generate_data_insights(user_input, exoplanet_data, max_tokens=500, temperature=0.3):
    """Have GPT-4 relate *user_input* to a table of exoplanet data.

    The DataFrame is serialised to CSV text (without the index column) and
    embedded in the prompt together with the query.

    Returns:
        The model's reply with surrounding whitespace stripped.
    """
    csv_text = exoplanet_data.to_csv(index=False)
    prompt_parts = [
        "Analyze the following user query and provide relevant insights based on the provided exoplanet data.\n\n",
        f"User Query: {user_input}\n\n",
        f"Exoplanet Data:\n{csv_text}\n\n",
        "Please provide insights that are relevant to the user's query.",
    ]
    completion = client.chat.completions.create(
        model="gpt-4",
        messages=[
            {"role": "system", "content": "You are an expert in analyzing astronomical data and generating insights."},
            {"role": "user", "content": "".join(prompt_parts)},
        ],
        max_tokens=max_tokens,
        temperature=temperature,
    )
    return completion.choices[0].message.content.strip()
_ADS_REFERENCES_MARKER = "\nADS References:"


def _split_off_references(text):
    """Separate the plain-text ADS references trailer from the rest of *text*.

    generate_response() appends references as an "ADS References:" line
    followed by one "- ..." line per reference, with no "### " heading.
    Without this split, the heading-based section loop would bury them in
    the last section.

    Returns:
        ``(body, reference_lines)``. *reference_lines* is the list of lines
        after the marker, or ``None`` when there is no trailer.
    """
    idx = text.rfind(_ADS_REFERENCES_MARKER)
    if idx == -1:
        return text, None
    reference_lines = text[idx + 1:].split('\n')[1:]  # drop the "ADS References:" line itself
    return text[:idx], reference_lines


def _docx_is_separator_cells(cells):
    """Return True when *cells* form a Markdown header separator (``---``, ``:--:``, ...)."""
    return bool(cells) and all(cell and '-' in cell and set(cell) <= {'-', ':'} for cell in cells)


def _docx_set_cell_width_pct(cell, width):
    """Give *cell* exactly one ``<w:tcW w:type="pct">`` width element with value *width*."""
    from docx.oxml.ns import qn

    tc_pr = cell._element.get_or_add_tcPr()
    # Cells created by python-docx's add_table() already carry a <w:tcW>.
    # Replace it instead of appending a second, duplicate width element.
    for existing in tc_pr.findall(qn('w:tcW')):
        tc_pr.remove(existing)
    tc_pr.append(parse_xml(f'<w:tcW w:w="{width}" w:type="pct" {nsdecls("w")}/>'))


def _docx_add_requirements_table(doc, section):
    """Render the requirements section as a Word table, plus any non-table prose after it."""
    doc.add_heading('Observations Requirements Table', level=1)
    body_lines = section.split('\n')[1:]  # everything after the "### ..." heading line

    rows = []
    for line in body_lines:
        if '|' not in line:
            continue
        cells = [cell.strip() for cell in line.split('|')[1:-1]]
        # The |---|---| separator is Markdown syntax, not a data row.
        if cells and not _docx_is_separator_cells(cells):
            rows.append(cells)

    if rows:
        # Size the grid by the widest row so ragged rows cannot overflow it.
        table = doc.add_table(rows=len(rows), cols=max(len(row) for row in rows))
        table.style = 'Table Grid'
        for i, row in enumerate(rows):
            for j, cell_text in enumerate(row):
                cell = table.cell(i, j)
                cell.text = cell_text
                # Width value 2500 of type pct, kept from the original layout.
                _docx_set_cell_width_pct(cell, 2500)

    prose = '\n'.join(line for line in body_lines if '|' not in line and line.strip())
    if prose:
        doc.add_paragraph(prose.strip())


def _docx_add_references(doc, reference_lines):
    """Add an "ADS References" heading and one paragraph per non-blank reference line."""
    doc.add_heading('ADS References', level=1)
    for reference in reference_lines:
        if reference.strip():
            doc.add_paragraph(reference.strip())


def export_to_word(response_content, subdomain_definition, science_goal):
    """Write a generated SCDD response to a new ``.docx`` file and return its path.

    Document layout:
    1. A title.
    2. "Subdomain Definition" and "Science Goal" sections.
    3. The response, split on Markdown ``### `` headings:
       - the section whose heading line mentions "Observations Requirements
         Table" becomes a real Word table;
       - ADS references become one paragraph each;
       - every other section is added as a single plain paragraph.

    Returns:
        Path of a temporary file created with ``delete=False``. The file is
        not removed here; whoever serves it (the Gradio File output) owns it.
    """
    doc = Document()
    doc.add_heading('AI Generated SCDD', 0)
    doc.add_heading('Subdomain Definition:', level=1)
    doc.add_paragraph(subdomain_definition)
    doc.add_heading('Science Goal:', level=1)
    doc.add_paragraph(science_goal)

    body, reference_lines = _split_off_references(response_content or "")
    for section in body.split('### '):
        if not section.strip():
            continue
        heading_line = section.split('\n', 1)[0]
        # Search the whole heading line so numbered/bold headings such as
        # "5. **Observations Requirements Table**" are still recognised.
        if 'Observations Requirements Table' in heading_line:
            _docx_add_requirements_table(doc, section)
        elif section.startswith('ADS References'):
            _docx_add_references(doc, section.split('\n')[1:])
        else:
            doc.add_paragraph(section.strip())

    if reference_lines is not None:
        _docx_add_references(doc, reference_lines)

    # Write through the handle and close it. The original left the
    # NamedTemporaryFile open, leaking a descriptor per export.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".docx") as temp_file:
        doc.save(temp_file)
    return temp_file.name
def extract_table_from_response(gpt_response):
    """Return the lines of *gpt_response* that span its Markdown table(s).

    A "table line" is a line containing ``|`` that splits into more than three
    pieces on ``|``. For example, ``| a | b |`` gives 4 pieces, i.e. at least
    two cells. The result runs from the first table line to the last one,
    inclusive, so any non-table lines between them are included too.

    Returns:
        A list of lines, or None when the response contains no table line.
    """
    lines = gpt_response.strip().split("\n")
    table_indices = [
        i for i, line in enumerate(lines)
        if '|' in line and len(line.split('|')) > 3
    ]
    if not table_indices:
        return None
    # Use positions, not lines.index(text). Looking up a repeated row's text
    # returns its first occurrence and would cut the table short.
    return lines[table_indices[0]:table_indices[-1] + 1]
def _is_markdown_separator_line(line):
    """Return True for a Markdown header separator row such as ``|---|:--:|`` or ``| --- |``."""
    cells = [cell.strip() for cell in line.strip().strip('|').split('|')]
    return all(cell and '-' in cell and set(cell) <= {'-', ':'} for cell in cells)


def _fit_row_to_width(cells, width):
    """Pad or fold *cells* so the row has exactly *width* entries (*width* >= 1)."""
    if len(cells) < width:
        return cells + [''] * (width - len(cells))
    if len(cells) > width:
        # Extra cells usually come from a literal '|' inside a cell; keep that
        # text in the last column rather than crashing the DataFrame build.
        return cells[:width - 1] + [' | '.join(cells[width - 1:])]
    return cells


def gpt_response_to_dataframe(gpt_response):
    """Parse the Markdown table in *gpt_response* into a DataFrame.

    The header is the row directly above the first separator row. Every pipe
    row after the separator becomes a data row, padded or folded to the
    header's width.

    Returns:
        A DataFrame, or an empty DataFrame when there is no table, no
        separator row, or no header above the separator.
    """
    table_lines = extract_table_from_response(gpt_response)
    if not table_lines:
        return pd.DataFrame()

    # The extracted span can contain blank or prose lines between table rows.
    table_lines = [line for line in table_lines if '|' in line]

    sep_line_index = next(
        (i for i, line in enumerate(table_lines) if _is_markdown_separator_line(line)),
        None,
    )
    # Index 0 means there is no header row above the separator. Checking it
    # avoids table_lines[-1] silently picking the last line as the header.
    if sep_line_index is None or sep_line_index == 0:
        return pd.DataFrame()

    headers = [h.strip() for h in table_lines[sep_line_index - 1].split('|')[1:-1]]
    if not headers:
        return pd.DataFrame()

    rows = [
        _fit_row_to_width([cell.strip() for cell in row.split('|')[1:-1]], len(headers))
        for row in table_lines[sep_line_index + 1:]
    ]
    return pd.DataFrame(rows, columns=headers)
# Static HTML returned with every chatbot() result (Miro board embed).
_MIRO_IFRAME_HTML = """
<iframe width="768" height="432" src="https://miro.com/app/live-embed/uXjVKuVTcF8=/?moveToViewport=-331,-462,5434,3063&embedId=710273023721" frameborder="0" scrolling="no" allow="fullscreen; clipboard-read; clipboard-write" allowfullscreen></iframe>
"""

# Static HTML returned with every chatbot() result (link button to Mapify).
_MAPIFY_BUTTON_HTML = """
<style>
.mapify-button {
background: linear-gradient(135deg, #1E90FF 0%, #87CEFA 100%);
border: none;
color: white;
padding: 15px 35px;
text-align: center;
text-decoration: none;
display: inline-block;
font-size: 18px;
font-weight: bold;
margin: 20px 2px;
cursor: pointer;
border-radius: 25px;
transition: all 0.3s ease;
box-shadow: 0 4px 15px rgba(0, 0, 0, 0.2);
}
.mapify-button:hover {
background: linear-gradient(135deg, #4682B4 0%, #1E90FF 100%);
box-shadow: 0 6px 20px rgba(0, 0, 0, 0.3);
transform: scale(1.05);
}
</style>
<a href="https://mapify.so/app/new" target="_blank">
<button class="mapify-button">Create Mind Map on Mapify</button>
</a>
"""


def chatbot(user_input, science_objectives="", context="", subdomain="", use_encoder=False, max_tokens=150, temperature=0.7, top_p=0.9, frequency_penalty=0.5, presence_penalty=0.0, include_data_insights=False):
    """Run the SCDD pipeline for one "Generate SCDD" submission.

    Steps:
    1. Optionally pick the most relevant chunk of *context* with the bi-encoder.
    2. Look up NASA ADS references using *subdomain* as the query.
    3. Generate the structured response.
    4. Export the response to Word.
    5. Parse the response's requirements table into a DataFrame.

    Args:
        include_data_insights: When True, also query the NASA Exoplanet
            Archive and append GPT-generated "Data Insights" to the text
            output. Off by default: it adds a network query plus a GPT call
            per request, and a failure there should not sink the main answer.

    Returns:
        A tuple ``(response text, requirements DataFrame, .docx path,
        Miro iframe HTML, Mapify button HTML)``, in the order of the Gradio
        outputs.
    """
    science_objectives = science_objectives or ""
    subdomain = subdomain or ""

    relevant_context = ""
    if use_encoder and context and context.strip():
        retrieved = retrieve_relevant_context(user_input, context)
        # retrieve_relevant_context may return either a bare string or a
        # (text, score) tuple; only the text belongs in the prompt.
        relevant_context = retrieved[0] if isinstance(retrieved, tuple) else retrieved

    # The subdomain definition doubles as the ADS search query. With no
    # subdomain there is nothing meaningful to search for.
    references = fetch_nasa_ads_references(subdomain) if subdomain.strip() else []

    response = generate_response(
        user_input=user_input,
        science_objectives=science_objectives,
        relevant_context=relevant_context,
        references=references,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty
    )
    if science_objectives.strip():
        response = f"### Science Objectives (User-Defined):\n\n{science_objectives}\n\n" + response

    word_doc_path = export_to_word(response, subdomain, user_input)
    extracted_table_df = gpt_response_to_dataframe(response)

    if include_data_insights:
        data_insights = generate_data_insights(user_input, fetch_exoplanet_data())
        full_response = f"{response}\n\nData Insights:\n{data_insights}\n\nEnd of Response"
    else:
        full_response = f"{response}\n\nEnd of Response"

    return full_response, extracted_table_df, word_doc_path, _MIRO_IFRAME_HTML, _MAPIFY_BUTTON_HTML
# ---------------------------------------------------------------------------
# Gradio user interface.
# NOTE(review): Gradio Blocks appears to lay components out in the order they
# are created, so reordering the statements below would likely change the
# page layout. Confirm before restructuring.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# ExosAI - NASA SMD SCDD AI Assistant [version-0.91a]")
    # User inputs: the science goal, optional free-text context, and the
    # subdomain definition (chatbot() also uses it as the NASA ADS query).
    user_input = gr.Textbox(lines=5, placeholder="Enter your Science Goal...", label="Science Goal")
    context = gr.Textbox(lines=10, placeholder="Enter Context Text...", label="Context")
    subdomain = gr.Textbox(lines=2, placeholder="Define your Subdomain...", label="Subdomain Definition")
    # Optional user-written Science Objectives. The textbox starts hidden and
    # is revealed by the button.
    science_objectives_button = gr.Button("Manually Enter Science Objectives")
    science_objectives_input = gr.Textbox(
        lines=5,
        placeholder="Enter Science Objectives...",
        label="Science Objectives",
        visible=False  # Initially hidden
    )
    # Reveal the Science Objectives textbox on click. The handler is
    # registered here, inside the Blocks context.
    science_objectives_button.click(
        fn=lambda: gr.update(visible=True),  # Show textbox when clicked
        inputs=[],
        outputs=[science_objectives_input]
    )
    # Retrieval toggle and OpenAI sampling controls. Their defaults match the
    # keyword defaults of chatbot() and the values restored by clear_all().
    use_encoder = gr.Checkbox(label="Use NASA SMD Bi-Encoder for Context")
    max_tokens = gr.Slider(50, 2000, value=150, step=10, label="Max Tokens")
    temperature = gr.Slider(0.0, 1.0, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.1, label="Top-p")
    frequency_penalty = gr.Slider(0.0, 1.0, value=0.5, step=0.1, label="Frequency Penalty")
    presence_penalty = gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="Presence Penalty")
    # Outputs, declared in the same order as chatbot()'s return tuple.
    full_response = gr.Textbox(label="ExosAI finds...")
    extracted_table_df = gr.Dataframe(label="SC Requirements Table")
    word_doc_path = gr.File(label="Download SCDD", type="filepath")
    iframe_html = gr.HTML(label="Miro")
    mapify_button_html = gr.HTML(label="Generate Mind Map on Mapify")
    # Buttons: Generate + Reset, side by side
    with gr.Row():
        submit_button = gr.Button("Generate SCDD")
        clear_button = gr.Button("Reset")
    # "Generate SCDD": the inputs list is passed positionally, so its order
    # must match chatbot()'s parameters (user_input, science_objectives,
    # context, subdomain, use_encoder, max_tokens, ...).
    submit_button.click(
        fn=chatbot,
        inputs=[
            user_input, science_objectives_input, context, subdomain,
            use_encoder, max_tokens, temperature, top_p, frequency_penalty, presence_penalty
        ],
        outputs=[full_response, extracted_table_df, word_doc_path, iframe_html, mapify_button_html]
    )

    def clear_all():
        """Return reset values for every input and output component.

        The tuple order must match the ``outputs`` list given to
        ``clear_button.click`` below (10 inputs, then 5 outputs).
        """
        return (
            "",  # user_input
            "",  # science_objectives_input
            "",  # context
            "",  # subdomain
            False,  # use_encoder
            150,  # max_tokens
            0.7,  # temperature
            0.9,  # top_p
            0.5,  # frequency_penalty
            0.0,  # presence_penalty
            "",  # full_response (textbox output)
            None,  # extracted_table_df (DataFrame output)
            None,  # word_doc_path (File output)
            None,  # iframe_html (HTML output)
            None  # mapify_button_html (HTML output)
        )
    # "Reset": one output slot per value returned by clear_all(), same order.
    clear_button.click(
        fn=clear_all,
        inputs=[],
        outputs=[
            user_input, science_objectives_input, context, subdomain,
            use_encoder, max_tokens, temperature, top_p, frequency_penalty, presence_penalty,
            full_response, extracted_table_df, word_doc_path, iframe_html, mapify_button_html
        ]
    )
# Launch the app. This is a module-level statement, so it runs at import
# time; share=True asks Gradio for a public share link.
demo.launch(share=True)