sahilursa commited on
Commit
9c0a4f9
Β·
verified Β·
1 Parent(s): f0cb788

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -42
app.py CHANGED
@@ -1,5 +1,4 @@
1
  # Step 1: Import Libraries
2
- # No changes needed here, these will be installed from your requirements.txt
3
  import gradio as gr
4
  from sentence_transformers import SentenceTransformer
5
  import faiss
@@ -9,9 +8,7 @@ import pickle
9
  import os
10
 
11
  # Step 2: Configure API Key from Hugging Face Secrets
12
- # This section is modified to securely access the API key from your Space's secrets.
13
  try:
14
- # Make sure to set 'GOOGLE_API_KEY' in your Hugging Face Space's settings/secrets
15
  GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
16
  if GOOGLE_API_KEY is None:
17
  raise ValueError("GOOGLE_API_KEY not found in environment variables.")
@@ -21,18 +18,15 @@ except Exception as e:
21
  print(f"ERROR: Could not configure API key. Please ensure 'GOOGLE_API_KEY' is set in your Hugging Face Space secrets. Details: {e}")
22
 
23
  # Step 3: Define Data Path
24
- # This is updated to point to the 'rag_chatbot_data' folder you uploaded to the Space.
25
- # It no longer uses Google Drive.
26
  DATA_PATH = "data"
27
  vector_store_file = os.path.join(DATA_PATH, "vector_store.index")
28
  data_file = os.path.join(DATA_PATH, "data.pkl")
29
 
30
  # Step 4: Load Models and Pre-processed Data
31
- # This section is slightly restructured to load models only if the data is found.
32
  vector_store_data = None
33
  if os.path.exists(vector_store_file) and os.path.exists(data_file):
34
  try:
35
- print("Loading pre-processed data from the repository...")
36
  index = faiss.read_index(vector_store_file)
37
  with open(data_file, "rb") as f:
38
  stored_data = pickle.load(f)
@@ -40,30 +34,26 @@ if os.path.exists(vector_store_file) and os.path.exists(data_file):
40
  sources = stored_data["sources"]
41
  print("βœ… Data loaded successfully.")
42
 
43
- # Store everything in a state object for Gradio
44
  vector_store_data = (index, texts, sources)
45
 
46
- # Load the embedding and LLM models
47
  print("Loading AI and embedding models...")
48
  embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5')
49
- llm = genai.GenerativeModel('gemini-2.5-flash') # Using gemini-pro as it's a robust choice
50
  print("βœ… Models loaded successfully.")
51
 
52
  except Exception as e:
53
  print(f"❌ ERROR: An error occurred during data or model loading: {e}")
54
  vector_store_data = None
55
-
56
  else:
57
- print("❌ ERROR: Pre-processed data not found in the repository.")
58
- print(f"Please make sure '{vector_store_file}' and '{data_file}' exist in the '{DATA_PATH}' folder.")
59
 
60
 
61
- # Step 5: RAG and Chat Functions (No changes needed here)
62
  def get_relevant_context(query, index, top_k=5):
63
  query_embedding = embedding_model.encode([query])
64
  distances, indices = index.search(query_embedding, top_k)
65
  context = []
66
- # Ensure that the indices are within the bounds of the texts list
67
  for i in indices[0]:
68
  if i < len(texts):
69
  context.append({"text": texts[i], "source": sources[i]})
@@ -74,14 +64,17 @@ def chat_with_rag(message, history, vector_store_data):
74
  relevant_context = get_relevant_context(message, index)
75
  context_str = "\n\n".join([f"Source: {c['source']}\nContent: {c['text']}" for c in relevant_context])
76
 
 
77
  prompt = f"""
78
- You are an Expert scientist in the Halassa Lab at MIT, an expert in computational neuroscience.
79
- Your primary goal is to answer questions as thoroughly and accurately as possible. You may intelligently synthesize information from the provided context, which consists of key papers from the lab.
80
 
81
- Follow these rules strictly if you are using the provided context:
82
- 1. Do not simply copy-paste from the context. Read the relevant passages and formulate a comprehensive, well-written answer in your own words.
83
- 2. When your answer uses information directly from a provided paper, you MUST cite the source at the end of the output in a list. Use the format [filename - Page X].
84
- 3. If the provided papers offer relevant concepts but do not contain the full answer, use your broader knowledge of computational neuroscience to provide a more complete explanation.
 
 
85
 
86
  Context from the Halassa Lab's papers:
87
  ---
@@ -90,7 +83,7 @@ def chat_with_rag(message, history, vector_store_data):
90
 
91
  User Question: {message}
92
 
93
- Expert Answer:
94
  """
95
  try:
96
  response = llm.generate_content(prompt)
@@ -98,27 +91,70 @@ def chat_with_rag(message, history, vector_store_data):
98
  except Exception as e:
99
  return f"An error occurred with the AI model: {str(e)}"
100
 
101
- # Step 6: Gradio User Interface
102
- with gr.Blocks(theme=gr.themes.Soft(), title="Halassa Literature Chatbot") as demo:
103
- gr.Markdown("# Halassa Lab Onboarder")
104
- if vector_store_data is None:
105
- gr.Markdown("## ⚠️ Error: Could not load data or models. Please check the logs in the Hugging Face Space for details.")
106
- else:
107
- gr.Markdown("The documents have been pre-loaded. Ask your questions below.")
108
- chatbot_ui = gr.Chatbot(label="Chat History", height=600, layout="panel")
109
- message_box = gr.Textbox(label="Ask your question...", lines=3)
110
- clear_button = gr.ClearButton(components=[chatbot_ui, message_box])
111
-
112
- def respond(message, history):
113
- # Pass the loaded vector_store_data to the chat function
114
- response_text = chat_with_rag(message, history, vector_store_data)
115
- history.append((message, response_text))
116
- return "", history
117
-
118
- message_box.submit(respond, inputs=[message_box, chatbot_ui], outputs=[message_box, chatbot_ui])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  # Step 7: Launch the app
121
- # On Hugging Face, demo.launch() is all you need.
122
- # It will run the app and make it accessible.
123
  if vector_store_data:
 
 
 
 
 
 
124
  demo.launch()
 
1
  # Step 1: Import Libraries
 
2
  import gradio as gr
3
  from sentence_transformers import SentenceTransformer
4
  import faiss
 
8
  import os
9
 
10
  # Step 2: Configure API Key from Hugging Face Secrets
 
11
  try:
 
12
  GOOGLE_API_KEY = os.environ.get('GOOGLE_API_KEY')
13
  if GOOGLE_API_KEY is None:
14
  raise ValueError("GOOGLE_API_KEY not found in environment variables.")
 
18
  print(f"ERROR: Could not configure API key. Please ensure 'GOOGLE_API_KEY' is set in your Hugging Face Space secrets. Details: {e}")
19
 
20
  # Step 3: Define Data Path
 
 
21
  DATA_PATH = "data"
22
  vector_store_file = os.path.join(DATA_PATH, "vector_store.index")
23
  data_file = os.path.join(DATA_PATH, "data.pkl")
24
 
25
  # Step 4: Load Models and Pre-processed Data
 
26
  vector_store_data = None
27
  if os.path.exists(vector_store_file) and os.path.exists(data_file):
28
  try:
29
+ print(f"Loading pre-processed data from the '{DATA_PATH}' directory...")
30
  index = faiss.read_index(vector_store_file)
31
  with open(data_file, "rb") as f:
32
  stored_data = pickle.load(f)
 
34
  sources = stored_data["sources"]
35
  print("βœ… Data loaded successfully.")
36
 
 
37
  vector_store_data = (index, texts, sources)
38
 
 
39
  print("Loading AI and embedding models...")
40
  embedding_model = SentenceTransformer('BAAI/bge-large-en-v1.5')
41
+ llm = genai.GenerativeModel('gemini-1.5-flash-latest')
42
  print("βœ… Models loaded successfully.")
43
 
44
  except Exception as e:
45
  print(f"❌ ERROR: An error occurred during data or model loading: {e}")
46
  vector_store_data = None
 
47
  else:
48
+ print(f"❌ ERROR: Pre-processed data not found in the '{DATA_PATH}' directory.")
49
+ print(f"Please make sure '{vector_store_file}' and '{data_file}' exist.")
50
 
51
 
52
+ # Step 5: RAG and Chat Functions (with Updated System Prompt)
53
  def get_relevant_context(query, index, top_k=5):
54
  query_embedding = embedding_model.encode([query])
55
  distances, indices = index.search(query_embedding, top_k)
56
  context = []
 
57
  for i in indices[0]:
58
  if i < len(texts):
59
  context.append({"text": texts[i], "source": sources[i]})
 
64
  relevant_context = get_relevant_context(message, index)
65
  context_str = "\n\n".join([f"Source: {c['source']}\nContent: {c['text']}" for c in relevant_context])
66
 
67
+ # --- NEW, ROBUST SYSTEM PROMPT ---
68
  prompt = f"""
69
+ You are a friendly and engaging science communicator for the Halassa Lab at MIT.
70
+ Your goal is to explain the lab's complex computational neuroscience research to a general audience that has little to no scientific background.
71
 
72
+ Follow these rules strictly:
73
+ 1. **Simplify, Don't Dumb Down:** Break down complex topics into simple, easy-to-understand concepts. Use analogies and real-world examples (e.g., "think of the thalamus as a busy switchboard operator for the brain").
74
+ 2. **Be Engaging and Accessible:** Write in a clear, conversational, and friendly tone. Avoid jargon at all costs. If you must use a technical term, explain it immediately in simple terms.
75
+ 3. **Synthesize and Explain:** Do not just copy-paste from the provided research papers. Read the relevant context and formulate a comprehensive, well-written answer in your own words.
76
+ 4. **Cite Your Sources:** At the end of your response, if you used information from the provided papers, you MUST include a "Sources:" list. Use the format [filename - Page X]. This adds credibility.
77
+ 5. **Stay Focused:** Only answer questions related to the Halassa Lab's work based on the provided context. If the context doesn't contain the answer, state that the information isn't available in the provided documents.
78
 
79
  Context from the Halassa Lab's papers:
80
  ---
 
83
 
84
  User Question: {message}
85
 
86
+ Your Friendly Explanation:
87
  """
88
  try:
89
  response = llm.generate_content(prompt)
 
91
  except Exception as e:
92
  return f"An error occurred with the AI model: {str(e)}"
93
 
94
+ # Step 6: Polished Gradio User Interface
95
+ # Define a professional, clean theme for a science lab
96
+ theme = gr.themes.Monochrome(
97
+ primary_hue="indigo",
98
+ secondary_hue="blue",
99
+ neutral_hue="slate",
100
+ font=[gr.themes.GoogleFont("Space Mono"), "ui-sans-serif", "system-ui", "sans-serif"],
101
+ ).set(
102
+ body_background_fill="*neutral_50",
103
+ block_background_fill="*neutral_100",
104
+ block_border_width="1px",
105
+ block_shadow="*shadow_md",
106
+ button_shadow="*shadow_md",
107
+ button_primary_background_fill="*primary_500",
108
+ button_primary_background_fill_hover="*primary_600",
109
+ button_primary_text_color="white",
110
+ )
111
+
112
+ # List of example questions for users to click on
113
+ example_questions = [
114
+ "What is the main focus of the Halassa Lab's research, in simple terms?",
115
+ "How does the thalamus affect our attention? Can you use an analogy?",
116
+ "What have you learned about sleep and how it affects the brain?",
117
+ "Explain the role of the thalamic reticular nucleus like I'm a high school student.",
118
+ ]
119
+
120
+ with gr.Blocks(theme=theme, title="Halassa Lab AI Explainer") as demo:
121
+ with gr.Column():
122
+ # --- Header Section ---
123
+ gr.Image("https://d39w22sdwnt1s2.cloudfront.net/wp-content/uploads/2023/10/Halassa-M-2023-Option-2-1200x675.jpg", show_label=False, show_download_button=False, container=False)
124
+ gr.Markdown("# Welcome to the Halassa Lab AI Explainer")
125
+ gr.Markdown("I'm an AI assistant trained on the Halassa Lab's research. My job is to explain our work on how the brain processes information, pays attention, and sleeps. Ask me a question in plain English, or try one of the examples below!")
126
+
127
+ # --- Chatbot Interface ---
128
+ chatbot_ui = gr.Chatbot(label="Conversation", height=500, layout="panel", bubble_full_width=False)
129
+ message_box = gr.Textbox(label="Ask your question here...", lines=3, placeholder="e.g., How does the brain filter out distractions?")
130
+
131
+ with gr.Row():
132
+ submit_button = gr.Button("Submit", variant="primary", scale=2)
133
+ clear_button = gr.ClearButton(components=[chatbot_ui, message_box], value="Clear Conversation", scale=1)
134
+
135
+ # --- Example Questions Section ---
136
+ gr.Examples(
137
+ examples=example_questions,
138
+ inputs=message_box,
139
+ label="Click an example to get started:"
140
+ )
141
+
142
+ # --- Logic to make the chatbot respond ---
143
+ def respond(message, history):
144
+ response_text = chat_with_rag(message, history, vector_store_data)
145
+ history.append((message, response_text))
146
+ return "", history
147
+
148
+ submit_button.click(respond, inputs=[message_box, chatbot_ui], outputs=[message_box, chatbot_ui])
149
+ message_box.submit(respond, inputs=[message_box, chatbot_ui], outputs=[message_box, chatbot_ui])
150
+
151
 
152
  # Step 7: Launch the app
 
 
153
  if vector_store_data:
154
+ demo.launch()
155
+ else:
156
+ # Display a clear error message in the UI if data loading failed
157
+ with gr.Blocks() as demo:
158
+ gr.Markdown("## ⚠️ Application Error")
159
+ gr.Markdown("Could not load the necessary data and models. Please check the Hugging Face Space logs for details. This may be due to missing data files or an incorrect API key.")
160
  demo.launch()