Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
# Step 1: Import Libraries
|
| 2 |
-
# No changes needed here, these will be installed from your requirements.txt
|
| 3 |
import gradio as gr
|
| 4 |
from sentence_transformers import SentenceTransformer
|
| 5 |
import faiss
|
|
@@ -9,9 +8,7 @@ import pickle
|
|
| 9 |
import os
|
| 10 |
|
| 11 |
# Step 2: Configure API Key from Hugging Face Secrets
|
| 12 |
-
# This section is modified to securely access the API key from your Space's secrets.
|
| 13 |
try:
|
| 14 |
-
# Make sure to set 'GOOGLE_API_KEY' in your Hugging Face Space's settings/secrets
|
| 15 |
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
|
| 16 |
if GOOGLE_API_KEY is None:
|
| 17 |
raise ValueError("GOOGLE_API_KEY not found in environment variables.")
|
|
@@ -21,18 +18,15 @@ except Exception as e:
|
|
| 21 |
print(f"ERROR: Could not configure API key. Please ensure 'GOOGLE_API_KEY' is set in your Hugging Face Space secrets. Details: {e}")
|
| 22 |
|
| 23 |
# Step 3: Define Data Path
|
| 24 |
-
# This is updated to point to the 'rag_chatbot_data' folder you uploaded to the Space.
|
| 25 |
-
# It no longer uses Google Drive.
|
| 26 |
DATA_PATH = "data"
|
| 27 |
vector_store_file = os.path.join(DATA_PATH, "vector_store.index")
|
| 28 |
data_file = os.path.join(DATA_PATH, "data.pkl")
|
| 29 |
|
| 30 |
# Step 4: Load Models and Pre-processed Data
|
| 31 |
-
# This section is slightly restructured to load models only if the data is found.
|
| 32 |
vector_store_data = None
|
| 33 |
if os.path.exists(vector_store_file) and os.path.exists(data_file):
|
| 34 |
try:
|
| 35 |
-
print("Loading pre-processed data from the
|
| 36 |
index = faiss.read_index(vector_store_file)
|
| 37 |
with open(data_file, "rb") as f:
|
| 38 |
stored_data = pickle.load(f)
|
|
@@ -40,30 +34,26 @@ if os.path.exists(vector_store_file) and os.path.exists(data_file):
|
|
| 40 |
sources = stored_data["sources"]
|
| 41 |
print("✅ Data loaded successfully.")
|
| 42 |
|
| 43 |
-
# Store everything in a state object for Gradio
|
| 44 |
vector_store_data = (index, texts, sources)
|
| 45 |
|
| 46 |
-
# Load the embedding and LLM models
|
| 47 |
print("Loading AI and embedding models...")
|
| 48 |
embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5')
|
| 49 |
-
llm = genai.GenerativeModel('gemini-
|
| 50 |
print("✅ Models loaded successfully.")
|
| 51 |
|
| 52 |
except Exception as e:
|
| 53 |
print(f"❌ ERROR: An error occurred during data or model loading: {e}")
|
| 54 |
vector_store_data = None
|
| 55 |
-
|
| 56 |
else:
|
| 57 |
-
print("β ERROR: Pre-processed data not found in the
|
| 58 |
-
print(f"Please make sure '{vector_store_file}' and '{data_file}' exist
|
| 59 |
|
| 60 |
|
| 61 |
-
# Step 5: RAG and Chat Functions (
|
| 62 |
def get_relevant_context(query, index, top_k=5):
|
| 63 |
query_embedding = embedding_model.encode([query])
|
| 64 |
distances, indices = index.search(query_embedding, top_k)
|
| 65 |
context = []
|
| 66 |
-
# Ensure that the indices are within the bounds of the texts list
|
| 67 |
for i in indices[0]:
|
| 68 |
if i < len(texts):
|
| 69 |
context.append({"text": texts[i], "source": sources[i]})
|
|
@@ -74,14 +64,17 @@ def chat_with_rag(message, history, vector_store_data):
|
|
| 74 |
relevant_context = get_relevant_context(message, index)
|
| 75 |
context_str = "\n\n".join([f"Source: {c['source']}\nContent: {c['text']}" for c in relevant_context])
|
| 76 |
|
|
|
|
| 77 |
prompt = f"""
|
| 78 |
-
You are
|
| 79 |
-
Your
|
| 80 |
|
| 81 |
-
Follow these rules strictly
|
| 82 |
-
1.
|
| 83 |
-
2.
|
| 84 |
-
3.
|
|
|
|
|
|
|
| 85 |
|
| 86 |
Context from the Halassa Lab's papers:
|
| 87 |
---
|
|
@@ -90,7 +83,7 @@ def chat_with_rag(message, history, vector_store_data):
|
|
| 90 |
|
| 91 |
User Question: {message}
|
| 92 |
|
| 93 |
-
|
| 94 |
"""
|
| 95 |
try:
|
| 96 |
response = llm.generate_content(prompt)
|
|
@@ -98,27 +91,70 @@ def chat_with_rag(message, history, vector_store_data):
|
|
| 98 |
except Exception as e:
|
| 99 |
return f"An error occurred with the AI model: {str(e)}"
|
| 100 |
|
| 101 |
-
# Step 6: Gradio User Interface
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
|
| 120 |
# Step 7: Launch the app
|
| 121 |
-
# On Hugging Face, demo.launch() is all you need.
|
| 122 |
-
# It will run the app and make it accessible.
|
| 123 |
if vector_store_data:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
demo.launch()
|
|
|
|
| 1 |
# Step 1: Import Libraries
|
|
|
|
| 2 |
import gradio as gr
|
| 3 |
from sentence_transformers import SentenceTransformer
|
| 4 |
import faiss
|
|
|
|
| 8 |
import os
|
| 9 |
|
| 10 |
# Step 2: Configure API Key from Hugging Face Secrets
|
|
|
|
| 11 |
try:
|
|
|
|
| 12 |
GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
|
| 13 |
if GOOGLE_API_KEY is None:
|
| 14 |
raise ValueError("GOOGLE_API_KEY not found in environment variables.")
|
|
|
|
| 18 |
print(f"ERROR: Could not configure API key. Please ensure 'GOOGLE_API_KEY' is set in your Hugging Face Space secrets. Details: {e}")
|
| 19 |
|
| 20 |
# Step 3: Define Data Path
|
|
|
|
|
|
|
| 21 |
DATA_PATH = "data"
|
| 22 |
vector_store_file = os.path.join(DATA_PATH, "vector_store.index")
|
| 23 |
data_file = os.path.join(DATA_PATH, "data.pkl")
|
| 24 |
|
| 25 |
# Step 4: Load Models and Pre-processed Data
|
|
|
|
| 26 |
vector_store_data = None
|
| 27 |
if os.path.exists(vector_store_file) and os.path.exists(data_file):
|
| 28 |
try:
|
| 29 |
+
print(f"Loading pre-processed data from the '{DATA_PATH}' directory...")
|
| 30 |
index = faiss.read_index(vector_store_file)
|
| 31 |
with open(data_file, "rb") as f:
|
| 32 |
stored_data = pickle.load(f)
|
|
|
|
| 34 |
sources = stored_data["sources"]
|
| 35 |
print("✅ Data loaded successfully.")
|
| 36 |
|
|
|
|
| 37 |
vector_store_data = (index, texts, sources)
|
| 38 |
|
|
|
|
| 39 |
print("Loading AI and embedding models...")
|
| 40 |
embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5')
|
| 41 |
+
llm = genai.GenerativeModel('gemini-1.5-flash-latest')
|
| 42 |
print("✅ Models loaded successfully.")
|
| 43 |
|
| 44 |
except Exception as e:
|
| 45 |
print(f"❌ ERROR: An error occurred during data or model loading: {e}")
|
| 46 |
vector_store_data = None
|
|
|
|
| 47 |
else:
|
| 48 |
+
print(f"❌ ERROR: Pre-processed data not found in the '{DATA_PATH}' directory.")
|
| 49 |
+
print(f"Please make sure '{vector_store_file}' and '{data_file}' exist.")
|
| 50 |
|
| 51 |
|
| 52 |
+
# Step 5: RAG and Chat Functions (with Updated System Prompt)
|
| 53 |
def get_relevant_context(query, index, top_k=5):
|
| 54 |
query_embedding = embedding_model.encode([query])
|
| 55 |
distances, indices = index.search(query_embedding, top_k)
|
| 56 |
context = []
|
|
|
|
| 57 |
for i in indices[0]:
|
| 58 |
if i < len(texts):
|
| 59 |
context.append({"text": texts[i], "source": sources[i]})
|
|
|
|
| 64 |
relevant_context = get_relevant_context(message, index)
|
| 65 |
context_str = "\n\n".join([f"Source: {c['source']}\nContent: {c['text']}" for c in relevant_context])
|
| 66 |
|
| 67 |
+
# --- NEW, ROBUST SYSTEM PROMPT ---
|
| 68 |
prompt = f"""
|
| 69 |
+
You are a friendly and engaging science communicator for the Halassa Lab at MIT.
|
| 70 |
+
Your goal is to explain the lab's complex computational neuroscience research to a general audience that has little to no scientific background.
|
| 71 |
|
| 72 |
+
Follow these rules strictly:
|
| 73 |
+
1. **Simplify, Don't Dumb Down:** Break down complex topics into simple, easy-to-understand concepts. Use analogies and real-world examples (e.g., "think of the thalamus as a busy switchboard operator for the brain").
|
| 74 |
+
2. **Be Engaging and Accessible:** Write in a clear, conversational, and friendly tone. Avoid jargon at all costs. If you must use a technical term, explain it immediately in simple terms.
|
| 75 |
+
3. **Synthesize and Explain:** Do not just copy-paste from the provided research papers. Read the relevant context and formulate a comprehensive, well-written answer in your own words.
|
| 76 |
+
4. **Cite Your Sources:** At the end of your response, if you used information from the provided papers, you MUST include a "Sources:" list. Use the format [filename - Page X]. This adds credibility.
|
| 77 |
+
5. **Stay Focused:** Only answer questions related to the Halassa Lab's work based on the provided context. If the context doesn't contain the answer, state that the information isn't available in the provided documents.
|
| 78 |
|
| 79 |
Context from the Halassa Lab's papers:
|
| 80 |
---
|
|
|
|
| 83 |
|
| 84 |
User Question: {message}
|
| 85 |
|
| 86 |
+
Your Friendly Explanation:
|
| 87 |
"""
|
| 88 |
try:
|
| 89 |
response = llm.generate_content(prompt)
|
|
|
|
| 91 |
except Exception as e:
|
| 92 |
return f"An error occurred with the AI model: {str(e)}"
|
| 93 |
|
| 94 |
+
# Step 6: Polished Gradio User Interface
|
| 95 |
+
# Define a professional, clean theme for a science lab
|
| 96 |
+
theme = gr.themes.Monochrome(
|
| 97 |
+
primary_hue="indigo",
|
| 98 |
+
secondary_hue="blue",
|
| 99 |
+
neutral_hue="slate",
|
| 100 |
+
font=[gr.themes.GoogleFont("Space Mono"), "ui-sans-serif", "system-ui", "sans-serif"],
|
| 101 |
+
).set(
|
| 102 |
+
body_background_fill="*neutral_50",
|
| 103 |
+
block_background_fill="*neutral_100",
|
| 104 |
+
block_border_width="1px",
|
| 105 |
+
block_shadow="*shadow_md",
|
| 106 |
+
button_shadow="*shadow_md",
|
| 107 |
+
button_primary_background_fill="*primary_500",
|
| 108 |
+
button_primary_background_fill_hover="*primary_600",
|
| 109 |
+
button_primary_text_color="white",
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# List of example questions for users to click on
|
| 113 |
+
example_questions = [
|
| 114 |
+
"What is the main focus of the Halassa Lab's research, in simple terms?",
|
| 115 |
+
"How does the thalamus affect our attention? Can you use an analogy?",
|
| 116 |
+
"What have you learned about sleep and how it affects the brain?",
|
| 117 |
+
"Explain the role of the thalamic reticular nucleus like I'm a high school student.",
|
| 118 |
+
]
|
| 119 |
+
|
| 120 |
+
with gr.Blocks(theme=theme, title="Halassa Lab AI Explainer") as demo:
|
| 121 |
+
with gr.Column():
|
| 122 |
+
# --- Header Section ---
|
| 123 |
+
gr.Image("https://d39w22sdwnt1s2.cloudfront.net/wp-content/uploads/2023/10/Halassa-M-2023-Option-2-1200x675.jpg", show_label=False, show_download_button=False, container=False)
|
| 124 |
+
gr.Markdown("# Welcome to the Halassa Lab AI Explainer")
|
| 125 |
+
gr.Markdown("I'm an AI assistant trained on the Halassa Lab's research. My job is to explain our work on how the brain processes information, pays attention, and sleeps. Ask me a question in plain English, or try one of the examples below!")
|
| 126 |
+
|
| 127 |
+
# --- Chatbot Interface ---
|
| 128 |
+
chatbot_ui = gr.Chatbot(label="Conversation", height=500, layout="panel", bubble_full_width=False)
|
| 129 |
+
message_box = gr.Textbox(label="Ask your question here...", lines=3, placeholder="e.g., How does the brain filter out distractions?")
|
| 130 |
+
|
| 131 |
+
with gr.Row():
|
| 132 |
+
submit_button = gr.Button("Submit", variant="primary", scale=2)
|
| 133 |
+
clear_button = gr.ClearButton(components=[chatbot_ui, message_box], value="Clear Conversation", scale=1)
|
| 134 |
+
|
| 135 |
+
# --- Example Questions Section ---
|
| 136 |
+
gr.Examples(
|
| 137 |
+
examples=example_questions,
|
| 138 |
+
inputs=message_box,
|
| 139 |
+
label="Click an example to get started:"
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
# --- Logic to make the chatbot respond ---
|
| 143 |
+
def respond(message, history):
|
| 144 |
+
response_text = chat_with_rag(message, history, vector_store_data)
|
| 145 |
+
history.append((message, response_text))
|
| 146 |
+
return "", history
|
| 147 |
+
|
| 148 |
+
submit_button.click(respond, inputs=[message_box, chatbot_ui], outputs=[message_box, chatbot_ui])
|
| 149 |
+
message_box.submit(respond, inputs=[message_box, chatbot_ui], outputs=[message_box, chatbot_ui])
|
| 150 |
+
|
| 151 |
|
| 152 |
# Step 7: Launch the app
|
|
|
|
|
|
|
| 153 |
if vector_store_data:
|
| 154 |
+
demo.launch()
|
| 155 |
+
else:
|
| 156 |
+
# Display a clear error message in the UI if data loading failed
|
| 157 |
+
with gr.Blocks() as demo:
|
| 158 |
+
gr.Markdown("## ⚠️ Application Error")
|
| 159 |
+
gr.Markdown("Could not load the necessary data and models. Please check the Hugging Face Space logs for details. This may be due to missing data files or an incorrect API key.")
|
| 160 |
demo.launch()
|