tyang4 committed
Commit 750a410 · verified · 1 Parent(s): 7641e44

Delete app.py

Files changed (1)
  1. app.py +0 -925
app.py DELETED
@@ -1,925 +0,0 @@
1
- import streamlit as st
2
- st.set_page_config(
3
- page_title="🔬 Explainable Multi-Agent BioData Constructor",
4
- layout="centered",
5
- initial_sidebar_state="collapsed"
6
- )
7
- from neo4j import GraphDatabase
8
- import openai
9
- import pandas as pd
10
- import os
11
- import re
12
- import hashlib
13
- import json
14
- import pydeck as pdk
15
- import faiss
16
- import numpy as np
17
- from sklearn.preprocessing import normalize
18
- from transformers import AutoTokenizer, AutoModel
19
- import torch
20
- import ast
21
- import textwrap
22
- import requests
23
-
24
- # ============================== CONFIGURATION ==============================
25
- NEO4J_URI = st.secrets["NEO4J_URI"]
26
- NEO4J_USERNAME = st.secrets["NEO4J_USERNAME"]
27
- NEO4J_PASSWORD = st.secrets["NEO4J_PASSWORD"]
28
- openai.api_key = st.secrets["openai_api_key"]
29
-
30
- # ============================== DOWNLOAD ==============================
31
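- # fetch the prebuilt node/relationship CSVs and their embedding matrices into /tmp on first run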
- def download_if_missing(url, local_path):
32
- if not os.path.exists(local_path):
33
- with open(local_path, "wb") as f:
34
- f.write(requests.get(url).content)
35
-
36
- base_url = "https://github.com/Tianyu-yang-anna/EcoData-collector/releases/download/v1.0"
37
- files = {
38
- "nodes.csv": "/tmp/nodes.csv",
39
- "nodes_embeddings.npy": "/tmp/nodes_embeddings.npy",
40
- "relationships.csv": "/tmp/relationships.csv",
41
- "relationships_embeddings.npy": "/tmp/relationships_embeddings.npy"
42
- }
43
-
44
- for fname, path in files.items():
45
- download_if_missing(f"{base_url}/{fname}", path)
46
-
47
- # ============================== NEO4J DRIVER ==============================
48
- @st.cache_resource(show_spinner=False)
49
- def create_driver():
50
- try:
51
- driver = GraphDatabase.driver(
52
- NEO4J_URI,
53
- auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
54
- )
55
- with driver.session() as session:
56
- session.run("RETURN 1")
57
- return driver
58
- except Exception as e:
59
- st.error(f"🔴 Neo4j connection failed: {e}")
60
- return None
61
-
62
- driver = create_driver()
63
- # ============================== SIMPLE GPT HELPER ==============================
64
- openai_client = openai.OpenAI(api_key=openai.api_key)
65
-
66
- def gpt_chat(sys_msg: str, user_msg: str, **kwargs):
67
- rsp = openai_client.chat.completions.create(
68
- model="gpt-4o",
69
- messages=[{"role": "system", "content": sys_msg}, {"role": "user", "content": user_msg}],
70
- **kwargs
71
- )
72
- return rsp.choices[0].message.content.strip()
73
-
74
- # ============================== EMBEDDING ENCODER ==============================
75
- class SimpleEncoder:
76
- def __init__(self):
77
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
78
- self.tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
79
- self.model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2").to(self.device)
80
- self.model.eval()
81
-
82
- def encode(self, texts, batch_size: int = 16):
83
- embeddings = []
84
- for i in range(0, len(texts), batch_size):
85
- batch = texts[i : i + batch_size]
86
- with torch.no_grad():
87
- inputs = self.tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(self.device)
88
- outputs = self.model(**inputs)
89
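- # mean-pool the token embeddings into one fixed-size vector per input text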
- batch_emb = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
90
- embeddings.append(batch_emb)
91
- return np.vstack(embeddings)
92
-
93
-
94
- @st.cache_resource(show_spinner=False)
95
- def get_encoder():
96
- return SimpleEncoder()
97
-
98
- # ============================== FAISS INDEX LOADING ==============================
99
- csv_file_pairs = [
100
- ("/tmp/nodes.csv", "/tmp/nodes_embeddings.npy"),
101
- ("/tmp/relationships.csv", "/tmp/relationships_embeddings.npy"),
102
- ]
103
-
104
- for csv_path, npy_path in csv_file_pairs:
105
- if not os.path.exists(npy_path):
106
- st.error(f"❌ Embedding file not found: {npy_path}")
107
- st.stop()
108
-
109
- @st.cache_resource(show_spinner=False)
110
- def load_embeddings_and_faiss_indexes(file_pairs):
111
- index_list, metadatas = [], []
112
- for csv_path, npy_path in file_pairs:
113
- try:
114
- df = pd.read_csv(csv_path).fillna("")
115
- emb = np.load(npy_path).astype("float32")
116
- index = faiss.IndexFlatIP(emb.shape[1])
117
- if faiss.get_num_gpus() > 0:
118
- res = faiss.StandardGpuResources()
119
- index = faiss.index_cpu_to_gpu(res, 0, index)
120
- index.add(emb)
121
- index_list.append(index)
122
- metadatas.append(df)
123
- except Exception as e:
124
- st.warning(f"⚠️ Failed to load {csv_path} / {npy_path}: {e}")
125
- index_list.append(None)
126
- metadatas.append(pd.DataFrame())
127
- return index_list, metadatas
128
-
129
- csv_faiss_indexes, csv_metadatas = load_embeddings_and_faiss_indexes(csv_file_pairs)
130
-
131
- # ============================== DATAFRAME UTILITIES ==============================
132
-
133
- def flatten_props(df: pd.DataFrame) -> pd.DataFrame:
134
- if "props" not in df.columns:
135
- return df
136
- try:
137
- props_df = df["props"].apply(ast.literal_eval).apply(pd.Series)
138
- out = pd.concat([df.drop(columns=["props"]), props_df], axis=1)
139
- # st.write("✅ props flattened, new columns:", list(props_df.columns))
140
- return out
141
- except Exception as e:
142
- st.warning(f"⚠️ Failed to parse props column: {e}")
143
- return df
144
-
145
- def unpack_singletons(df: pd.DataFrame) -> pd.DataFrame:
146
- for col in df.columns:
147
- if df[col].apply(lambda x: isinstance(x, (list, tuple)) and len(x) == 1).any():
148
- df[col] = df[col].apply(lambda x: x[0] if isinstance(x, (list, tuple)) and len(x) == 1 else x)
149
- return df
150
-
151
- def standardize_latlon(df: pd.DataFrame) -> pd.DataFrame:
152
- """
153
- Normalize column names to latitudes / longitudes
154
- If duplicate columns share a name, keep the first and drop the rest
155
- Keep longitudes in place and move latitudes to its immediate right
156
- """
157
- # ---------- 1) Normalize column names ----------
158
- col_map = {}
159
- for col in df.columns:
160
- low = col.lower()
161
- if "lat" in low and "lon" not in low:
162
- col_map[col] = "latitudes"
163
- elif ("lon" in low or "lng" in low):
164
- col_map[col] = "longitudes"
165
- df = df.rename(columns=col_map)
166
-
167
- # ---------- 2) Handle duplicate columns ----------
168
- # after the rename, identical column names can collide; keep only the first occurrence of each
169
- while df.columns.duplicated().any():
170
- dup_col = df.columns[df.columns.duplicated()][0]
171
- # keep the first occurrence and drop every other column with the same name
172
- first_idx = list(df.columns).index(dup_col)
173
- keep = [True] * len(df.columns)
174
- for i, c in enumerate(df.columns):
175
- if c == dup_col and i != first_idx:
176
- keep[i] = False
177
- df = df.loc[:, keep]
178
-
179
- # ---------- 3) Convert to numeric ----------
180
- for c in ("latitudes", "longitudes"):
181
- if c in df.columns and not isinstance(df[c], pd.Series):
182
- # if duplicates slipped through, df[c] may still be a DataFrame; take its first column
183
- df[c] = df[c].iloc[:, 0]
184
- if c in df.columns:
185
- df[c] = df[c].apply(
186
- lambda x: x[0] if isinstance(x, (list, tuple)) and len(x) == 1 else x
187
- )
188
- df[c] = pd.to_numeric(df[c], errors="coerce")
189
-
190
- # ---------- 4) Reorder: put latitudes right after longitudes ----------
191
- if {"longitudes", "latitudes"}.issubset(df.columns):
192
- cols = list(df.columns)
193
- lon_idx = cols.index("longitudes")
194
- lat_idx = cols.index("latitudes")
195
- if lat_idx != lon_idx + 1:
196
- cols.pop(lat_idx)
197
- cols.insert(lon_idx + 1, "latitudes")
198
- df = df[cols]
199
-
200
- return df
201
-
202
-
203
-
204
- # ===== CSV fallback retrieval =====
205
- @st.cache_data(show_spinner=False)
206
- def rag_csv_fallback(subtask, top_k=2000):
207
- encoder = get_encoder()
208
- query_vec = encoder.encode([subtask])
209
- query_vec = normalize(query_vec, axis=1).astype("float32")
210
- if not np.any(query_vec):
211
- return pd.DataFrame()
212
- all_results = []
213
- for index, metadata in zip(csv_faiss_indexes, csv_metadatas):
214
- if index is None or metadata.empty:
215
- continue
216
- distances, indices = index.search(query_vec, top_k)
217
- retrieved = metadata.iloc[indices[0]].copy()
218
- all_results.append(retrieved)
219
- if all_results:
220
- return pd.concat(all_results).drop_duplicates().reset_index(drop=True)
221
- return pd.DataFrame()
222
-
223
-
224
-
225
- def generate_cypher_with_gpt(subtask: str) -> str:
226
- prompt = f"""
227
- You are an expert Cypher query generator for a Neo4j biodiversity database. The schema is as follows:
228
-
229
- Node Types and Properties:
230
- - Observation: animal_name, date, latitude, longitude
231
- - Species: name, species_full_name
232
- - Site: name
233
- - County: name
234
- - State: name
235
- - Hurricane: name
236
- - Policy: title, description
237
- - ClimateEvent: event_type, date
238
- - TemperatureReading: value, date, location
239
- - Precipitation: amount, date, location
240
-
241
- Relationship Types:
242
- - OBSERVED_IN: (Observation)-[:OBSERVED_IN]->(Site)
243
- - OBSERVED_ORGANISM: (Observation)-[:OBSERVED_ORGANISM]->(Species)
244
- - BELONGS_TO: (Site)-[:BELONGS_TO]->(County)
245
- - IN_COUNTY: (Observation)-[:IN_COUNTY]->(County)
246
- - IN_STATE: (County)-[:IN_STATE]->(State)
247
- - interactsWith: (Species)-[:interactsWith]->(Species)
248
- - preysOn: (Species)-[:preysOn]->(Species)
249
-
250
- Your task is to generate a **precise and efficient** Cypher query for the following subtask:
251
- "{subtask}"
252
-
253
- Guidelines:
254
- - Do NOT return all nodes of a type (e.g., all Species) unless the subtask explicitly asks for it.
255
- - If a location (county/state) is mentioned or implied, include location filtering using IN_COUNTY, IN_STATE, or BELONGS_TO.
256
- - If the subtask implies a taxonomic or common name group (e.g., frog, snake, salmon), apply CONTAINS or STARTS WITH filters on Species.name or species_full_name, using toLower(...) for case-insensitive matching.
257
- - If the subtask includes a time range, include date filtering.
258
- - Prefer using DISTINCT to avoid redundant results.
259
- - Only return fields that are clearly needed to fulfill the subtask.
260
-
261
- Return your response strictly as a **JSON object** with the following fields:
262
- - "intent": a short description of what the query does
263
- - "cypher_query": the Cypher query
264
- - "fields": a list of returned field names (e.g., ["species", "county", "date"])
265
-
266
- Do not include any explanation or commentary—only return the JSON object.
267
- """
268
- client = openai.OpenAI(api_key=st.secrets["openai_api_key"])
269
- response = client.chat.completions.create(
270
- model="gpt-4o",
271
- messages=[{"role": "user", "content": prompt}],
272
- temperature=0
273
- )
274
- content = response.choices[0].message.content.strip()
275
- content = re.sub(r"^(json|python)?", "", content, flags=re.IGNORECASE).strip()
276
- content = re.sub(r"$", "", content).strip()
277
-
278
- try:
279
- cypher_json = json.loads(content)
280
- return cypher_json["cypher_query"]
281
- except Exception as e:
282
- return ""
283
-
284
-
285
- def intelligent_retriever_agent(subtask, saved_hashes=None):
286
- if saved_hashes is None:
287
- saved_hashes = set()
288
- st.success("🔍 Attempting to retrieve data from the KN-Wildlife knowledge graph…")
289
- cypher_query = generate_cypher_with_gpt(subtask)
290
- cypher_df = pd.DataFrame()
291
- if cypher_query.strip():
292
- st.code(cypher_query, language="cypher")
293
- try:
294
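- # strip any trailing LIMIT clause so the full result set comes back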
- query = re.sub(r"(?i)LIMIT\s+\d+\s*$", "", cypher_query)
295
- with driver.session() as session:
296
- result = session.run(query)
297
- cypher_df = pd.DataFrame(result.data())
298
- except Exception as e:
299
- st.error(f"🚨 Cypher execution error: {e}")
300
- st.code(query, language="cypher")
301
- # decide fallback
302
- fallback_needed = False
303
- if cypher_df.empty:
304
- # st.warning("⚠️ Cypher query returned no data. Trying CSV fallback…")
305
- fallback_needed = True
306
- else:
307
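- # fingerprint the result; a duplicate of an earlier dataset or a tiny result triggers the CSV fallback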
- df_hash = hashlib.md5(cypher_df.to_csv(index=False).encode()).hexdigest()
308
- st.write(f"ℹ️ Cypher rows: {len(cypher_df)} | duplicate?: {df_hash in saved_hashes}")
309
- if df_hash in saved_hashes or len(cypher_df) < 10:
310
- fallback_needed = True
311
- if fallback_needed:
312
- csv_df = rag_csv_fallback(subtask)
313
- if not csv_df.empty:
314
- csv_df = flatten_props(csv_df)
315
- csv_df = unpack_singletons(csv_df)
316
- csv_df = standardize_latlon(csv_df)
317
- # st.success("✅ CSV fallback successful.")
318
- return csv_df
319
- st.warning("❌ CSV fallback also returned nothing.")
320
- return pd.DataFrame()
321
- # good cypher
322
- st.success("✅ Cypher query successful. Using Cypher result.")
323
- cypher_df = flatten_props(cypher_df)
324
- cypher_df = unpack_singletons(cypher_df)
325
- cypher_df = standardize_latlon(cypher_df)
326
- if "species" not in cypher_df.columns and "animal_name" in cypher_df.columns:
327
- cypher_df["species"] = cypher_df["animal_name"]
328
- if "date" in cypher_df.columns:
329
- cypher_df["date"] = pd.to_datetime(cypher_df["date"], errors="coerce")
330
- cypher_df.rename(columns={"latitudes": "latitude", "longitudes": "longitude", "lat": "latitude", "lon": "longitude"}, inplace=True)
331
- for col in ("latitude", "longitude"):
332
- if col in cypher_df.columns:
333
- cypher_df[col] = pd.to_numeric(cypher_df[col], errors="coerce")
334
- return cypher_df
335
-
336
-
337
- def planner_agent(question: str) -> str:
338
- prompt = f"""
339
- You are a **research‑data planning assistant**.
340
-
341
- ------------------------ 📝 TASK ------------------------
342
- Your job is to list the **separate data sets** a researcher must collect
343
- to answer the research question below.
344
-
345
- *Each data set* should be focused on one clearly defined entity or
346
- phenomenon (e.g. "Tracks of hurricanes affecting Florida since 1950",
347
- "Geo‑tagged snake observations in Florida 2000‑present").
348
-
349
- -------------------- 📋 OUTPUT FORMAT --------------------
350
- Write 1–6 blocks. For **each** block use *both* lines exactly:
351
-
352
- Dataset Need X: <Concise title, ≤ 10 words>
353
- Description: <Why this data matters—1 short sentence>
354
-
355
- ⚠️ Do NOT add extra lines or markdown.
356
- ⚠️ Keep variable names short; no code blocks; no quotes.
357
-
358
- -------------------- 🔍 RESEARCH QUESTION --------------------
359
- {question}
360
- """
361
- rsp = openai_client.chat.completions.create(
362
- model="gpt-4o",
363
- messages=[
364
- {"role": "system", "content": "You are an expert research planner."},
365
- {"role": "user", "content": prompt}
366
- ],
367
- temperature=0.2
368
- )
369
- return rsp.choices[0].message.content.strip()
370
-
371
-
372
-
373
- def evaluate_dataset_with_gpt(subtask: str, df: pd.DataFrame, client=openai_client) -> str:
374
- max_columns = 50
375
- selected_cols = df.columns[:max_columns]
376
- column_info = {col: str(df[col].dtype) for col in selected_cols}
377
- sample_rows = df.head(3)[selected_cols].to_dict(orient="records") # take 3 example rows
378
-
379
- prompt = f"""
380
- You are a data‑validation assistant. Decide whether the dataset below is useful for the research subtask.
381
-
382
- ===== TASK =====
383
- Subtask: "{subtask}"
384
-
385
- ===== DATASET PREVIEW =====
386
- Schema (first {len(selected_cols)} columns):
387
- {json.dumps(column_info, indent=2)}
388
- Sample rows (3 max):
389
- {json.dumps(sample_rows, indent=2)}
390
-
391
- ===== OUTPUT INSTRUCTIONS (follow strictly) =====
392
- Case A – Relevant:
393
- • Write exactly two sentences, each no more than 30 words.
394
- • Summarize what the dataset contains and why it helps the subtask.
395
- • Do not mention column names or list individual rows.
396
-
397
- Case B – Not relevant:
398
- • Write one or two sentences, each no more than 30 words, **describing only what the dataset contains**.
399
- • Do **not** mention the subtask, relevance, suitability, limitations, or missing information (avoid phrases like “not related,” “does not focus,” “irrelevant,” etc.).
400
- • After the sentences, output the header **Additionally, here are some external resources you might find helpful:** on a new line. Format your output in markdown as:
401
- - [Name of Source](URL)
402
- • Then list 2–3 bullet points, each on its own line, starting with “- ” followed immediately by a URL likely to contain the needed data.
403
- • No additional commentary.
404
-
405
-
406
-
407
- General rules:
408
- Plain text only — no code fences. Markdown link syntax (`[text](url)`) is allowed.
409
- """
410
-
411
- rsp = client.chat.completions.create(
412
- model="gpt-4o",
413
- messages=[{"role": "user", "content": prompt}],
414
- temperature=0.3,
415
- )
416
- return rsp.choices[0].message.content.strip()
417
-
418
- # def evaluate_dataset_with_gpt(subtask: str, df: pd.DataFrame,client=openai_client) -> str:
419
- # # select only the first N columns to avoid an overly long prompt
420
- # max_columns = 10
421
- # selected_columns = df.columns[:max_columns]
422
-
423
- # # extract column names and their dtypes
424
- # column_info = {col: str(df[col].dtype) for col in selected_columns}
425
-
426
- # # take a few sample rows
427
- # sample_data = df.head(50)[selected_columns].to_dict(orient="records")
428
-
429
- # # build the prompt
430
- # prompt = f"""
431
- # You are a data validation assistant. Your task is to summarize what this dataset represents.
432
-
433
- # Subtask: {subtask}
434
-
435
- # Here are the dataset's column names and data types:
436
- # {json.dumps(column_info, indent=2)}
437
-
438
- # Here are a few sample rows:
439
- # {json.dumps(sample_data, indent=2)}
440
-
441
- # Your response should be concise (2-3 sentences).
442
- # Focus on the dataset's content and how it might help with the subtask.
443
- # Do not list column names or describe individual rows.
444
- # Your output format is as follows:
445
- # If you judge the data relevant to the data needed, output 2-3 sentences introducing the dataset.
446
- # If you judge the data not relevant to the data needed, output 2-4 links to external resources.
447
- # """
448
- # # call GPT-4o
449
- # rsp = client.chat.completions.create(
450
- # model="gpt-4o",
451
- # messages=[{"role": "user", "content": prompt}],
452
- # temperature=0.3
453
- # )
454
- # return rsp.choices[0].message.content.strip()
455
-
456
-
457
-
458
-
459
-
460
- def external_resource_recommender(subtask: str, client=openai_client) -> str:
461
- prompt = f"""
462
- You are a helpful assistant for researchers. Please recommend 3 reliable and relevant online datasets or websites that can help with the following subtask:
463
-
464
- "{subtask}"
465
-
466
- Format your output in markdown as:
467
- - [Name of Source](URL)
468
- - [Name of Source](URL)
469
- - [Name of Source](URL)
470
- """
471
- rsp = client.chat.completions.create(
472
- model="gpt-4o",
473
- messages=[{"role": "user", "content": prompt}],
474
- temperature=0.3
475
- )
476
- return rsp.choices[0].message.content.strip()
477
-
478
-
479
-
480
- def fallback_query_router(subtask: str, driver) -> pd.DataFrame:
481
- text = subtask.lower()
482
-
483
- with driver.session() as session:
484
-
485
- # --- 1. Species "where … observed/found …" ---
486
- if "where" in text and ("observed" in text or "found" in text):
487
- query = """
488
- MATCH (o:Observation)-[:OBSERVED_ORGANISM]->(s:Species)
489
- RETURN s.name AS species, o.site_name AS location, o.date AS date
490
- ORDER BY o.date DESC
491
- """
492
-
493
- # --- 2. before / after a given year ---
494
- elif "before" in text or "after" in text:
495
- years = re.findall(r'\b(?:19|20)\d{2}\b', text)  # non-capturing group so findall returns full years
496
- if years:
497
- op = "<" if "before" in text else ">"
498
- query = f"""
499
- MATCH (o:Observation)-[:OBSERVED_ORGANISM]->(s:Species)
500
- WHERE o.date {op} date('{years[0]}-01-01')
501
- RETURN s.name AS species, o.site_name AS location, o.date AS date
502
- ORDER BY o.date DESC
503
- """
504
- else:
505
- query = "RETURN 1"
506
-
507
- # --- 3. Hurricane-related ---
508
- elif "hurricane" in text:
509
- query = """
510
- MATCH (o:Observation)-[:OBSERVED_AT]->(h:Hurricane),
511
- (o)-[:OBSERVED_ORGANISM]->(s:Species),
512
- (o)-[:OBSERVED_IN]->(site)-[:BELONGS_TO]->(c:County)-[:IN_STATE]->(st:State)
513
- WHERE st.name = 'Florida'
514
- RETURN h.name AS hurricane,
515
- s.name AS species,
516
- site.name AS site,
517
- c.name AS county,
518
- o.date AS date
519
- ORDER BY o.date DESC
520
- """
521
-
522
- # --- 4. Predation / predator ---
523
- elif "preys on" in text or "predator" in text:
524
- query = """
525
- MATCH (s1:Species)-[:preysOn]->(s2:Species)
526
- RETURN s1.name AS predator, s2.name AS prey
527
- """
528
-
529
- # --- 5. Default fallback ---
530
- else:
531
- query = """
532
- MATCH (o:Observation)
533
- RETURN o.animal_name AS species, o.site_name AS location, o.date AS date
534
- """
535
-
536
- # --- Execute the query ---
537
- result = session.run(query)
538
- df = pd.DataFrame(result.data())
539
-
540
- if df.empty:
541
- st.info("🌐 I couldn't find relevant data in KN‑Wildlife. Let me check external sources for you...")
542
- suggestions = external_resource_recommender(subtask)
543
- st.markdown(suggestions)
544
-
545
- return df
546
-
547
-
548
- def save_dataset(df: pd.DataFrame, filename: str) -> str:
549
- if len(df) < 10:
550
- st.warning(f"❌ Dataset too small to save: only {len(df)} rows.")
551
- return ""
552
- os.makedirs("saved_datasets", exist_ok=True)
553
- path = f"saved_datasets/{filename}.csv"
554
- if os.path.exists(path):
555
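- # compare content hashes so an unchanged dataset is not rewritten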
- old_hash = hashlib.md5(open(path, 'rb').read()).hexdigest()
556
- new_hash = hashlib.md5(df.to_csv(index=False).encode()).hexdigest()
557
- if old_hash == new_hash:
558
- st.info(f"ℹ️ Dataset saved: {filename}.csv")
559
- return path
560
- df.to_csv(path, index=False)
561
- st.info(f"✅ Dataset saved: {filename}.csv")
562
- return path
563
- # ===================== CHART SUGGESTION (MODIFIED MAP SECTION) =====================
564
-
565
- def suggest_charts_with_gpt(df: pd.DataFrame) -> str:
566
- """Generate Streamlit chart code for automatic visualisation."""
567
- try:
568
- # st.write("🟢 COLS‑DEBUG:", list(df.columns))
569
-
570
- # Ensure dates are parsed
571
- if "date" in df.columns:
572
- df["date"] = df["date"].apply(lambda x: x[0] if isinstance(x, (list, tuple)) and len(x) == 1 else x)
573
- df["date"] = pd.to_datetime(df["date"], errors="coerce")
574
-
575
- if "animal_name" in df.columns and "species" not in df.columns:
576
- df["species"] = df["animal_name"]
577
-
578
- df.rename(columns={"latitudes": "latitude", "longitudes": "longitude"}, inplace=True)
579
-
580
- chart_code = """
581
- # --- Species Bar Chart ---
582
- if 'species' in df.columns:
583
- st.markdown('📊 Count of Observations by Species')
584
- try:
585
- species_counts = df['species'].astype(str).value_counts()
586
- st.bar_chart(species_counts)
587
- except Exception as e:
588
- st.warning(f'⚠️ Could not render species chart: {e}')
589
-
590
- # --- Timeline Line Chart ---
591
- if 'date' in df.columns:
592
- st.markdown('📈 Observations Over Time')
593
- try:
594
- timeline = df['date'].dropna().value_counts().sort_index()
595
- st.line_chart(timeline)
596
- except Exception as e:
597
- st.warning(f'⚠️ Could not render date chart: {e}')
598
-
599
- # --- Map Visualisation (highlight all points) ---
600
- if 'latitude' in df.columns and 'longitude' in df.columns:
601
- st.markdown('🗺️ Observation Locations on Map')
602
- try:
603
- coords = (
604
- df[['latitude', 'longitude']]
605
- .apply(pd.to_numeric, errors='coerce')
606
- .dropna()
607
- .rename(columns={'latitude': 'lat', 'longitude': 'lon'})
608
- )
609
- coords = coords[
610
- (coords['lat'].between(-90, 90)) &
611
- (coords['lon'].between(-180, 180))
612
- ]
613
- if len(coords) == 0:
614
- st.warning('⚠️ No valid coordinates to plot on the map.')
615
- else:
616
- # ---------- 1) View state ----------
617
- try:
618
- vs_tmp = pdk.data_utils.compute_view(coords[['lon', 'lat']])
619
- view_state = (
620
- pdk.ViewState(**vs_tmp, pitch=0, bearing=0)
621
- if isinstance(vs_tmp, dict) else vs_tmp
622
- )
623
- view_state.pitch = 0
624
- view_state.bearing = 0
625
- except Exception:
626
- view_state = pdk.ViewState(
627
- latitude=coords['lat'].mean(),
628
- longitude=coords['lon'].mean(),
629
- zoom=5,
630
- pitch=0,
631
- bearing=0,
632
- )
633
-
634
- # ---------- 2) Highlight layer ----------
635
- layer = pdk.Layer(
636
- 'ScatterplotLayer',
637
- data=coords,
638
- get_position='[lon, lat]',
639
- get_radius=50000,
640
- get_fill_color=[0, 255, 0, 200],
641
- get_line_color=[255, 255, 255],
642
- line_width_units='pixels',
643
- get_line_width=2,
644
- pickable=True,
645
- auto_highlight=True,
646
- )
647
-
648
- # ---------- 3) Assemble the Deck ----------
649
- deck = pdk.Deck(
650
- layers=[layer],
651
- initial_view_state=view_state,
652
- map_style='mapbox://styles/mapbox/light-v11',
653
- tooltip={'html': '<b>Lat:</b> {lat}<br/><b>Lon:</b> {lon}'},
654
- )
655
- st.pydeck_chart(deck)
656
- except Exception as e:
657
- st.warning(f'⚠️ Could not render map: {e}')
658
- """
659
- return textwrap.dedent(chart_code)
660
- except Exception as outer_error:
661
- return f"st.warning('❌ Chart generation failed: {outer_error}')"
662
-
663
-
664
-
665
-
666
- # ========= UI layout and connection ==========
667
- if "chat_history" not in st.session_state:
668
- st.session_state.chat_history = []
669
-
670
- # st.set_page_config(
671
- # page_title="🔬 Explainable Multi-Agent BioData Constructor",
672
- # layout="centered",
673
- # initial_sidebar_state="collapsed"
674
- # )
675
-
676
- # ——— Customize the main container width and font size ———
677
- st.markdown(
678
- """
679
- <style>
680
- /* body text */
681
- html, body, .block-container, .markdown-text-container {
682
- font-size: 19px !important; /* ← adjust this number */
683
- line-height: 1.6 !important;
684
- }
685
- /* widen the default max-width (about 700px) to 1600px; adjust as needed */
686
- .block-container {
687
- max-width: 1600px;
688
- }
689
- </style>
690
- """,
691
- unsafe_allow_html=True
692
- )
693
-
694
- st.title("🔬 EcoData collector")
695
-
696
-
697
- st.success("""
698
- 👋 Hi there! I’m **Lily**, your research assistant bot 🤖. I’m here to help you explore data sources related to your **complex research question**. Let’s work together to find the information you need!
699
-
700
- 💡 You can start by entering a research question like:
701
-
702
- - *In Florida, how do hurricanes affect the distribution of snakes?*
703
- - *How does precipitation impact salmon abundance in freshwater ecosystems?*
704
- - *How do climate change and urbanization jointly affect bird migration and diversity in Florida?*
705
- """)
706
-
707
- if driver:
708
- st.success("🟢 Connected to **KN-Wildlife** — a Neo4j-powered biodiversity graph focused on Florida’s species and ecosystems. I’ll start by checking what relevant data we already have in KN-Wildlife to support your research.")
709
-
710
- else:
711
- st.error("🔴 Failed to connect to KN-Wildlife! Please fix connection first.")
712
- st.stop()
713
-
714
- question = st.text_area("Enter your research question:", "")
715
-
716
- # Initialize session-state variables
717
- if "start_clicked" not in st.session_state:
718
- st.session_state.start_clicked = False
719
- if "subtask_plan" not in st.session_state:
720
- st.session_state.subtask_plan = ""
721
- if "ready_to_continue" not in st.session_state:
722
- st.session_state.ready_to_continue = False
723
- if "stop_requested" not in st.session_state:
724
- st.session_state.stop_requested = False
725
- if "visualization_ready" not in st.session_state:
726
- st.session_state.visualization_ready = False
727
- if "do_visualize" not in st.session_state:
728
- st.session_state.do_visualize = False
729
- if "all_dataframes" not in st.session_state:
730
- st.session_state.all_dataframes = []
731
- if "retrieval_done" not in st.session_state:
732
- st.session_state.retrieval_done = False
733
-
734
- # Button click triggers subtask decomposition
735
- if st.button("Let’s start") and question.strip():
736
- st.session_state.start_clicked = True
737
- st.session_state.subtask_plan = planner_agent(question)
738
- st.session_state.ready_to_continue = False
739
- st.session_state.stop_requested = False
740
- st.session_state.visualization_ready = False
741
- st.session_state.do_visualize = False
742
- st.session_state.all_dataframes = []
743
- st.session_state.retrieval_done = False
744
-
745
- # Stage 1: show the subtasks
746
- if st.session_state.start_clicked:
747
- # st.success("🧠 Now, I’ll break down your research question into several focused subtasks.")
748
- st.success("🧠 I’ve identified the distinct datasets you’ll need for this research question.")
749
- with st.expander("🔹 Curious how I split your question? Click to see!", expanded=True):
750
- st.write(st.session_state.subtask_plan)
751
-
752
- st.success("📌 I’m ready to roll up my sleeves — shall I start finding datasets for each subtask? 🕒 This step might take a little while, so thanks for your patience!")
753
-
754
- col1, col2 = st.columns([1, 1])
755
- with col1:
756
- if st.button("✅ Yes, go ahead", key="confirm_button"):
757
- st.session_state.ready_to_continue = True
758
- st.session_state.stop_requested = False
759
- with col2:
760
- if st.button("⛔ No, stop here", key="stop_button"):
761
- st.session_state.ready_to_continue = False
762
- st.session_state.stop_requested = True
763
-
764
-
765
- # ---------- Stage 2: data retrieval & rendering ----------
766
- if st.session_state.ready_to_continue:
767
-
768
- # 1) Determine which prefix the planner output uses
770
- # assume only two possibilities here: Subtask / Dataset Need
770
- if "Dataset Need" in st.session_state.subtask_plan:
771
- prefix = "Dataset Need"
772
- else:
773
- prefix = "Subtask"
774
-
775
- # 2) Build the regex with an f-string (rf = raw-formatted)
776
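- # non-greedy match from one "<prefix> N:" header up to the next header or the end of the text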
- pattern = rf"{prefix} \d+:.*?(?={prefix} \d+:|$)"
777
- subtasks = re.findall(pattern,
778
- st.session_state.subtask_plan,
779
- flags=re.DOTALL)
780
-
781
- # Warn if the planner produced no blocks
782
- if not subtasks:
783
- st.warning("⚠️ No dataset blocks detected in planner output.")
784
- st.stop()
785
-
786
- # Run retrieval only once
787
- if not st.session_state.retrieval_done: # ★
788
- progress_bar = st.progress(0)
789
- total = len(subtasks)
790
- saved_hashes = set()
791
- st.session_state.all_dataframes = []
792
-
793
-
794
- for idx, subtask in enumerate(subtasks):
795
- # with st.expander(f"🔹 Retrieving data for subtask {idx+1}:", expanded=True):
796
- with st.expander(f"🔹 Retrieving data for dataset need {idx+1}:", expanded=True):
797
- cleaned_subtask = "\n".join(subtask.strip().split("\n")[1:])
798
- st.markdown(cleaned_subtask)
799
-
800
- # ---------- First run: actually retrieve ----------
801
- if not st.session_state.retrieval_done: # ★
802
- df = intelligent_retriever_agent(subtask, saved_hashes)
803
-
804
- if not df.empty:
805
- df_hash = hashlib.md5(df.to_csv(index=False).encode()).hexdigest()
806
- if df_hash in saved_hashes:
807
- st.warning("⚠️ This dataset has already been saved — skipping duplicate.")
808
- elif len(df) < 10:
809
- st.warning(f"❌ This dataset is too small — just {len(df)} rows. Skipping save.")
810
- else:
811
- saved_hashes.add(df_hash)
812
- df = flatten_props(df)
813
- df = standardize_latlon(df)
814
- summary = evaluate_dataset_with_gpt(subtask, df)
815
- st.session_state.all_dataframes.append({
816
- "hash": df_hash,
817
- "df": df,
818
- "summary": summary
819
- })
820
- st.dataframe(df.head(50))
821
- save_path = save_dataset(df, f"subtask_{idx+1}")
822
- if save_path:
823
- # summary = evaluate_dataset_with_gpt(subtask, df)
824
- st.markdown("**📝 Dataset Introduction:**")
825
- st.write(summary)
826
- if 'progress_bar' in locals():
827
- progress_bar.progress((idx + 1) / total)
828
-
829
- # ---------- Later reruns: display only ----------
830
- else: # ★
831
- if idx < len(st.session_state.all_dataframes):
832
- # _hash, df = st.session_state.all_dataframes[idx]
833
- # df = standardize_latlon(df)
834
- # st.dataframe(df.head(50))
835
- entry = st.session_state.all_dataframes[idx]  # ➕ new entry
836
- df = standardize_latlon(entry["df"])
837
- st.dataframe(df.head(50))
838
- st.write(entry.get("summary", ""))
839
-
840
- # Mark retrieval as done once finished
841
- if not st.session_state.retrieval_done: # ★
842
- st.session_state.retrieval_done = True
843
- st.session_state.visualization_ready = bool(st.session_state.all_dataframes)
844
-
845
-
846
-
847
- if st.session_state.all_dataframes:
848
- st.session_state.visualization_ready = True
849
- else:
850
- st.success("🎉 All subtasks completed and datasets generated!💡 Feel free to ask me more questions anytime!")
851
- # st.success("🎉 All subtasks completed and datasets generated!")
852
- # st.success("💡 Feel free to ask Lily more questions anytime!")
853
-
854
- # Stage 3: ask whether to visualize
855
- if st.session_state.visualization_ready and not st.session_state.do_visualize:
856
- st.success("📊 All set! I’ve gathered the datasets. Ready to visualize them?")
857
-
858
- col1, col2 = st.columns([1, 1])
859
- with col1:
860
- if st.button("✅ Yes, go ahead", key="viz_confirm"):
861
- st.session_state.do_visualize = True
862
- with col2:
863
- if st.button("⛔ No, stop here", key="viz_stop"):
864
- st.session_state.visualization_ready = False
865
- st.success("🎉 All subtasks completed and datasets generated!💡 Feel free to ask me more questions anytime!")
866
- # st.success("🎉 All subtasks completed and datasets generated!")
867
- # st.success("💡 Feel free to ask Lily more questions anytime!")
868
-
869
- # Stage 3: data visualization
870
- if st.session_state.do_visualize:
871
- for i, entry in enumerate(st.session_state.all_dataframes):
872
- df = entry["df"]
873
- summary = entry.get("summary", "")
874
- if len(df) < 10:
875
- continue
876
- with st.expander(f"**🔹 Dataset {i + 1} Visualization**", expanded=True):
877
- st.markdown(f"Dataset {i + 1} Preview")
878
- st.dataframe(df.head(10))
879
- chart_code = suggest_charts_with_gpt(df)
880
- if chart_code:
881
- # st.markdown("🧠 The visualization code:")
882
- # st.code(chart_code, language="python")
883
- try:
884
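- # run the generated chart code with only st, pd, df, and pdk in scope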
- exec(chart_code, {"st": st, "pd": pd, "df": df, "pdk": pdk})
885
- except Exception as e:
886
- st.error(f"❌ Error running chart code: {e}")
887
-
888
- st.success("🎉 All subtasks completed and datasets generated!💡 Feel free to ask me more questions anytime!")
889
- # st.success("💡 Feel free to ask me more questions anytime!")
890
-
891
- if st.session_state.stop_requested:
892
- st.info("👍 No problem! You can review the subtasks above or revise your question.")
893
-
894
-
895
-
896
- # —— Sidebar: ChatGPT-style chat panel ——
897
- with st.sidebar.expander("💬 Chat with Lily", expanded=True):
898
- # Chat input box
899
- user_msg = st.chat_input("Type your question here…", key="sidebar_chat_input")
900
- if user_msg:
901
- # Assemble the current on-screen context
902
- context_parts = []
903
- if st.session_state.subtask_plan:
904
- context_parts.append("Subtasks:\n" + st.session_state.subtask_plan)
905
- for entry in st.session_state.all_dataframes:
906
- context_parts.append("Data summary:\n" + entry["summary"])
907
- page_context = "\n\n".join(context_parts)
908
-
909
- # Call the GPT helper
910
- with st.spinner("Lily is thinking…"):
911
- assistant_msg = gpt_chat(
912
- sys_msg=f"You are Lily, a research assistant. Here’s what’s on screen:\n\n{page_context}",
913
- user_msg=user_msg
914
- )
915
-
916
- # Save the exchange to chat history
917
- st.session_state.chat_history.append({"role": "user", "content": user_msg})
918
- st.session_state.chat_history.append({"role": "assistant", "content": assistant_msg})
919
-
920
- # Render the chat history
921
- for msg in st.session_state.chat_history:
922
- if msg["role"] == "user":
923
- st.chat_message("user").write(msg["content"])
924
- else:
925
- st.chat_message("assistant").write(msg["content"])