tyang4 committed
Commit fa31f00 · verified · 1 Parent(s): d43d402

Upload 2 files
Files changed (2)
  1. app.py +925 -0
  2. requirements.txt +10 -2
app.py ADDED
@@ -0,0 +1,925 @@
+ import streamlit as st
+ st.set_page_config(
+     page_title="🔬 Explainable Multi-Agent BioData Constructor",
+     layout="centered",
+     initial_sidebar_state="collapsed"
+ )
+ from neo4j import GraphDatabase
+ import openai
+ import pandas as pd
+ import os
+ import re
+ import hashlib
+ import json
+ import pydeck as pdk
+ import faiss
+ import numpy as np
+ from sklearn.preprocessing import normalize
+ from transformers import AutoTokenizer, AutoModel
+ import torch
+ import ast
+ import textwrap
+ import requests
+
+ # ============================== CONFIGURATION ==============================
+ NEO4J_URI = st.secrets["NEO4J_URI"]
+ NEO4J_USERNAME = st.secrets["NEO4J_USERNAME"]
+ NEO4J_PASSWORD = st.secrets["NEO4J_PASSWORD"]
+ openai.api_key = st.secrets["openai_api_key"]
+
+ # ============================== DOWNLOAD ==============================
+ def download_if_missing(url, local_path):
+     if not os.path.exists(local_path):
+         with open(local_path, "wb") as f:
+             f.write(requests.get(url).content)
+
+ base_url = "https://github.com/Tianyu-yang-anna/EcoData-collector/releases/download/v1.0"
+ files = {
+     "nodes.csv": "/tmp/nodes.csv",
+     "nodes_embeddings.npy": "/tmp/nodes_embeddings.npy",
+     "relationships.csv": "/tmp/relationships.csv",
+     "relationships_embeddings.npy": "/tmp/relationships_embeddings.npy"
+ }
+
+ for fname, path in files.items():
+     download_if_missing(f"{base_url}/{fname}", path)
+
+ # ============================== NEO4J DRIVER ==============================
+ @st.cache_resource(show_spinner=False)
+ def create_driver():
+     try:
+         driver = GraphDatabase.driver(
+             NEO4J_URI,
+             auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
+         )
+         with driver.session() as session:
+             session.run("RETURN 1")
+         return driver
+     except Exception as e:
+         st.error(f"🔴 Neo4j connection failed: {e}")
+         return None
+
+ driver = create_driver()
+ # ============================== SIMPLE GPT HELPER ==============================
+ openai_client = openai.OpenAI(api_key=openai.api_key)
+
+ def gpt_chat(sys_msg: str, user_msg: str, **kwargs):
+     rsp = openai_client.chat.completions.create(
+         model="gpt-4o",
+         messages=[{"role": "system", "content": sys_msg}, {"role": "user", "content": user_msg}],
+         **kwargs
+     )
+     return rsp.choices[0].message.content.strip()
+
+ # ============================== EMBEDDING ENCODER ==============================
+ class SimpleEncoder:
+     def __init__(self):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
+         self.model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2").to(self.device)
+         self.model.eval()
+
+     def encode(self, texts, batch_size: int = 16):
+         embeddings = []
+         for i in range(0, len(texts), batch_size):
+             batch = texts[i : i + batch_size]
+             with torch.no_grad():
+                 inputs = self.tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(self.device)
+                 outputs = self.model(**inputs)
+                 batch_emb = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
+             embeddings.append(batch_emb)
+         return np.vstack(embeddings)
+
+
+ @st.cache_resource(show_spinner=False)
+ def get_encoder():
+     return SimpleEncoder()
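+ # Usage sketch (assumed shapes): all-MiniLM-L6-v2 yields 384-dimensional vectors, so
+ # e.g. get_encoder().encode(["snake observations in Florida"]) returns an array of shape (1, 384).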
+
+ # ============================== FAISS INDEX LOADING ==============================
+ csv_file_pairs = [
+     ("/tmp/nodes.csv", "/tmp/nodes_embeddings.npy"),
+     ("/tmp/relationships.csv", "/tmp/relationships_embeddings.npy"),
+ ]
+
+ for csv_path, npy_path in csv_file_pairs:
+     if not os.path.exists(npy_path):
+         st.error(f"❌ Embedding file not found: {npy_path}")
+         st.stop()
+
+ @st.cache_resource(show_spinner=False)
+ def load_embeddings_and_faiss_indexes(file_pairs):
+     index_list, metadatas = [], []
+     for csv_path, npy_path in file_pairs:
+         try:
+             df = pd.read_csv(csv_path).fillna("")
+             emb = np.load(npy_path).astype("float32")
+             index = faiss.IndexFlatIP(emb.shape[1])
+             if faiss.get_num_gpus() > 0:
+                 res = faiss.StandardGpuResources()
+                 index = faiss.index_cpu_to_gpu(res, 0, index)
+             index.add(emb)
+             index_list.append(index)
+             metadatas.append(df)
+         except Exception as e:
+             st.warning(f"⚠️ Failed to load {csv_path} / {npy_path}: {e}")
+             index_list.append(None)
+             metadatas.append(pd.DataFrame())
+     return index_list, metadatas
+
+ csv_faiss_indexes, csv_metadatas = load_embeddings_and_faiss_indexes(csv_file_pairs)
+
+ # ============================== DATAFRAME UTILITIES ==============================
+
+ def flatten_props(df: pd.DataFrame) -> pd.DataFrame:
+     if "props" not in df.columns:
+         return df
+     try:
+         props_df = df["props"].apply(ast.literal_eval).apply(pd.Series)
+         out = pd.concat([df.drop(columns=["props"]), props_df], axis=1)
+         # st.write("✅ props flattened, new columns:", list(props_df.columns))
+         return out
+     except Exception as e:
+         st.warning(f"⚠️ Failed to parse props column: {e}")
+         return df
+
+ def unpack_singletons(df: pd.DataFrame) -> pd.DataFrame:
+     for col in df.columns:
+         if df[col].apply(lambda x: isinstance(x, (list, tuple)) and len(x) == 1).any():
+             df[col] = df[col].apply(lambda x: x[0] if isinstance(x, (list, tuple)) and len(x) == 1 else x)
+     return df
+
+ def standardize_latlon(df: pd.DataFrame) -> pd.DataFrame:
+     """
+     - Normalize column names to latitudes / longitudes
+     - If duplicate columns with the same name appear, keep the first and drop the rest
+     - Leave longitudes where it is and move latitudes directly to its right
+     """
+     # ---------- ① Normalize column names ----------
+     col_map = {}
+     for col in df.columns:
+         low = col.lower()
+         if "lat" in low and "lon" not in low:
+             col_map[col] = "latitudes"
+         elif ("lon" in low or "lng" in low):
+             col_map[col] = "longitudes"
+     df = df.rename(columns=col_map)
+
+     # ---------- ② Handle duplicate columns ----------
+     # The renaming above can create duplicate names; keep only the first occurrence of each
+     while df.columns.duplicated().any():
+         dup_col = df.columns[df.columns.duplicated()][0]
+         # Keep the first occurrence and drop every other column with the same name
+         first_idx = list(df.columns).index(dup_col)
+         keep = [True] * len(df.columns)
+         for i, c in enumerate(df.columns):
+             if c == dup_col and i != first_idx:
+                 keep[i] = False
+         df = df.loc[:, keep]
+
+     # ---------- ③ Convert to numeric ----------
+     for c in ("latitudes", "longitudes"):
+         if c in df.columns and not isinstance(df[c], pd.Series):
+             # If duplicates slipped through, df[c] may still be a DataFrame; take its first column
+             df[c] = df[c].iloc[:, 0]
+         if c in df.columns:
+             df[c] = df[c].apply(
+                 lambda x: x[0] if isinstance(x, (list, tuple)) and len(x) == 1 else x
+             )
+             df[c] = pd.to_numeric(df[c], errors="coerce")
+
+     # ---------- ④ Reorder: latitudes immediately after longitudes ----------
+     if {"longitudes", "latitudes"}.issubset(df.columns):
+         cols = list(df.columns)
+         lon_idx = cols.index("longitudes")
+         lat_idx = cols.index("latitudes")
+         if lat_idx != lon_idx + 1:
+             cols.pop(lat_idx)
+             cols.insert(lon_idx + 1, "latitudes")
+             df = df[cols]
+
+     return df
+
+
+
+ # ===== CSV fallback retrieval =====
+ @st.cache_data(show_spinner=False)
+ def rag_csv_fallback(subtask, top_k=2000):
+     encoder = get_encoder()
+     query_vec = encoder.encode([subtask])
+     query_vec = normalize(query_vec, axis=1).astype("float32")
+     if not np.any(query_vec):
+         return pd.DataFrame()
+     all_results = []
+     for index, metadata in zip(csv_faiss_indexes, csv_metadatas):
+         if index is None or metadata.empty:
+             continue
+         distances, indices = index.search(query_vec, top_k)
+         retrieved = metadata.iloc[indices[0]].copy()
+         all_results.append(retrieved)
+     if all_results:
+         return pd.concat(all_results).drop_duplicates().reset_index(drop=True)
+     return pd.DataFrame()
+
+
+
+ def generate_cypher_with_gpt(subtask: str) -> str:
+     prompt = f"""
+ You are an expert Cypher query generator for a Neo4j biodiversity database. The schema is as follows:
+
+ Node Types and Properties:
+ - Observation: animal_name, date, latitude, longitude
+ - Species: name, species_full_name
+ - Site: name
+ - County: name
+ - State: name
+ - Hurricane: name
+ - Policy: title, description
+ - ClimateEvent: event_type, date
+ - TemperatureReading: value, date, location
+ - Precipitation: amount, date, location
+
+ Relationship Types:
+ - OBSERVED_IN: (Observation)-[:OBSERVED_IN]->(Site)
+ - OBSERVED_ORGANISM: (Observation)-[:OBSERVED_ORGANISM]->(Species)
+ - BELONGS_TO: (Site)-[:BELONGS_TO]->(County)
+ - IN_COUNTY: (Observation)-[:IN_COUNTY]->(County)
+ - IN_STATE: (County)-[:IN_STATE]->(State)
+ - interactsWith: (Species)-[:interactsWith]->(Species)
+ - preysOn: (Species)-[:preysOn]->(Species)
+
+ Your task is to generate a **precise and efficient** Cypher query for the following subtask:
+ "{subtask}"
+
+ Guidelines:
+ - Do NOT return all nodes of a type (e.g., all Species) unless the subtask explicitly asks for it.
+ - If a location (county/state) is mentioned or implied, include location filtering using IN_COUNTY, IN_STATE, or BELONGS_TO.
+ - If the subtask implies a taxonomic or common name group (e.g., frog, snake, salmon), apply CONTAINS or STARTS WITH filters on Species.name or species_full_name, using toLower(...) for case-insensitive matching.
+ - If the subtask includes a time range, include date filtering.
+ - Prefer using DISTINCT to avoid redundant results.
+ - Only return fields that are clearly needed to fulfill the subtask.
+
+ Return your response strictly as a **JSON object** with the following fields:
+ - "intent": a short description of what the query does
+ - "cypher_query": the Cypher query
+ - "fields": a list of returned field names (e.g., ["species", "county", "date"])
+
+ Do not include any explanation or commentary—only return the JSON object.
+ """
+     client = openai.OpenAI(api_key=st.secrets["openai_api_key"])
+     response = client.chat.completions.create(
+         model="gpt-4o",
+         messages=[{"role": "user", "content": prompt}],
+         temperature=0
+     )
+     content = response.choices[0].message.content.strip()
+     # Strip Markdown code fences (```json / ```python) if the model wrapped its output
+     content = re.sub(r"^```(json|python)?", "", content, flags=re.IGNORECASE).strip()
+     content = re.sub(r"```$", "", content).strip()
+
+     try:
+         cypher_json = json.loads(content)
+         return cypher_json["cypher_query"]
+     except Exception:
+         return ""
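+ # Illustrative shape of the JSON the prompt above asks for (values are hypothetical):
+ # {"intent": "Snake observations by county in Florida",
+ #  "cypher_query": "MATCH (o:Observation)-[:OBSERVED_ORGANISM]->(s:Species) ... RETURN DISTINCT ...",
+ #  "fields": ["species", "county", "date"]}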
+
+
+ def intelligent_retriever_agent(subtask, saved_hashes=None):
+     """Try a GPT-generated Cypher query first; fall back to FAISS-over-CSV retrieval
+     when the Cypher result is empty, a duplicate of an earlier dataset, or very small."""
+     if saved_hashes is None:
+         saved_hashes = set()
+     st.success("🔍 Attempting to retrieve data from the KN-Wildlife knowledge graph…")
+     cypher_query = generate_cypher_with_gpt(subtask)
+     cypher_df = pd.DataFrame()
+     if cypher_query.strip():
+         st.code(cypher_query, language="cypher")
+         try:
+             query = re.sub(r"(?i)LIMIT\s+\d+\s*$", "", cypher_query)
+             with driver.session() as session:
+                 result = session.run(query)
+                 cypher_df = pd.DataFrame(result.data())
+         except Exception as e:
+             st.error(f"🚨 Cypher execution error: {e}")
+             st.code(query, language="cypher")
+     # decide fallback
+     fallback_needed = False
+     if cypher_df.empty:
+         # st.warning("⚠️ Cypher query returned no data. Trying CSV fallback…")
+         fallback_needed = True
+     else:
+         df_hash = hashlib.md5(cypher_df.to_csv(index=False).encode()).hexdigest()
+         st.write(f"ℹ️ Cypher rows: {len(cypher_df)} | duplicate?: {df_hash in saved_hashes}")
+         if df_hash in saved_hashes or len(cypher_df) < 10:
+             fallback_needed = True
+     if fallback_needed:
+         csv_df = rag_csv_fallback(subtask)
+         if not csv_df.empty:
+             csv_df = flatten_props(csv_df)
+             csv_df = unpack_singletons(csv_df)
+             csv_df = standardize_latlon(csv_df)
+             # st.success("✅ CSV fallback successful.")
+             return csv_df
+         st.warning("❌ CSV fallback also returned nothing.")
+         return pd.DataFrame()
+     # good cypher
+     st.success("✅ Cypher query successful. Using Cypher result.")
+     cypher_df = flatten_props(cypher_df)
+     cypher_df = unpack_singletons(cypher_df)
+     cypher_df = standardize_latlon(cypher_df)
+     if "species" not in cypher_df.columns and "animal_name" in cypher_df.columns:
+         cypher_df["species"] = cypher_df["animal_name"]
+     if "date" in cypher_df.columns:
+         cypher_df["date"] = pd.to_datetime(cypher_df["date"], errors="coerce")
+     cypher_df.rename(columns={"latitudes": "latitude", "longitudes": "longitude", "lat": "latitude", "lon": "longitude"}, inplace=True)
+     for col in ("latitude", "longitude"):
+         if col in cypher_df.columns:
+             cypher_df[col] = pd.to_numeric(cypher_df[col], errors="coerce")
+     return cypher_df
+
+
+ def planner_agent(question: str) -> str:
+     prompt = f"""
+ You are a **research‑data planning assistant**.
+
+ ------------------------ 📝 TASK ------------------------
+ Your job is to list the **separate data sets** a researcher must collect
+ to answer the research question below.
+
+ *Each data set* should be focused on one clearly defined entity or
+ phenomenon (e.g. "Tracks of hurricanes affecting Florida since 1950",
+ "Geo‑tagged snake observations in Florida 2000‑present").
+
+ -------------------- 📋 OUTPUT FORMAT --------------------
+ Write 1–6 blocks. For **each** block use *both* lines exactly:
+
+ Dataset Need X: <Concise title, ≤ 10 words>
+ Description: <Why this data matters—1 short sentence>
+
+ ⚠️ Do NOT add extra lines or markdown.
+ ⚠️ Keep variable names short; no code blocks; no quotes.
+
+ -------------------- 🔍 RESEARCH QUESTION --------------------
+ {question}
+ """
+     rsp = openai_client.chat.completions.create(
+         model="gpt-4o",
+         messages=[
+             {"role": "system", "content": "You are an expert research planner."},
+             {"role": "user", "content": prompt}
+         ],
+         temperature=0.2
+     )
+     return rsp.choices[0].message.content.strip()
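+ # Example (hypothetical) of the output format requested above, which Stage 2 later
+ # splits on the "Dataset Need \d+:" prefix:
+ #   Dataset Need 1: Hurricane tracks affecting Florida since 1950
+ #   Description: Needed to relate storm events to shifts in snake observations.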
+
+
+
+ def evaluate_dataset_with_gpt(subtask: str, df: pd.DataFrame, client=openai_client) -> str:
+     max_columns = 50
+     selected_cols = df.columns[:max_columns]
+     column_info = {col: str(df[col].dtype) for col in selected_cols}
+     sample_rows = df.head(3)[selected_cols].to_dict(orient="records")  # take 3 example rows
+
+     prompt = f"""
+ You are a data‑validation assistant. Decide whether the dataset below is useful for the research subtask.
+
+ ===== TASK =====
+ Subtask: "{subtask}"
+
+ ===== DATASET PREVIEW =====
+ Schema (first {len(selected_cols)} columns):
+ {json.dumps(column_info, indent=2)}
+ Sample rows (3 max):
+ {json.dumps(sample_rows, indent=2)}
+
+ ===== OUTPUT INSTRUCTIONS (follow strictly) =====
+ Case A – Relevant:
+ • Write exactly two sentences, each no more than 30 words.
+ • Summarize what the dataset contains and why it helps the subtask.
+ • Do not mention column names or list individual rows.
+
+ Case B – Not relevant:
+ • Write one or two sentences, each no more than 30 words, **describing only what the dataset contains**.
+ • Do **not** mention the subtask, relevance, suitability, limitations, or missing information (avoid phrases like “not related,” “does not focus,” “irrelevant,” etc.).
+ • After the sentences, output the header **Additionally, here are some external resources you might find helpful:** on a new line. Format your output in markdown as:
+ - [Name of Source](URL)
+ • Then list 2–3 bullet points, each on its own line, starting with “- ” followed immediately by a URL likely to contain the needed data.
+ • No additional commentary.
+
+ General rules:
+ Plain text only — no code fences. Markdown link syntax (`[text](url)`) is allowed.
+ """
+
+     rsp = client.chat.completions.create(
+         model="gpt-4o",
+         messages=[{"role": "user", "content": prompt}],
+         temperature=0.3,
+     )
+     return rsp.choices[0].message.content.strip()
+
+ # def evaluate_dataset_with_gpt(subtask: str, df: pd.DataFrame, client=openai_client) -> str:
+ #     # Select only the first N columns to keep the prompt within token limits
+ #     max_columns = 10
+ #     selected_columns = df.columns[:max_columns]
+
+ #     # Extract column names and their data types
+ #     column_info = {col: str(df[col].dtype) for col in selected_columns}
+
+ #     # Extract sample rows
+ #     sample_data = df.head(50)[selected_columns].to_dict(orient="records")
+
+ #     # Build the prompt
+ #     prompt = f"""
+ # You are a data validation assistant. Your task is to summarize what this dataset represents.
+
+ # Subtask: {subtask}
+
+ # Here are the dataset's column names and data types:
+ # {json.dumps(column_info, indent=2)}
+
+ # Here are a few sample rows:
+ # {json.dumps(sample_data, indent=2)}
+
+ # Your response should be concise (2-3 sentences).
+ # Focus on the dataset's content and how it might help with the subtask.
+ # Do not list column names or describe individual rows.
+ # Output format:
+ # If you judge the data relevant to the data needed, output 2-3 sentences introducing the dataset.
+ # If you judge the data not relevant to the data needed, output links to 2-4 external resources.
+ # """
+ #     # Call GPT-4o
+ #     rsp = client.chat.completions.create(
+ #         model="gpt-4o",
+ #         messages=[{"role": "user", "content": prompt}],
+ #         temperature=0.3
+ #     )
+ #     return rsp.choices[0].message.content.strip()
+
+
+
+ def external_resource_recommender(subtask: str, client=openai_client) -> str:
+     prompt = f"""
+ You are a helpful assistant for researchers. Please recommend 3 reliable and relevant online datasets or websites that can help with the following subtask:
+
+ "{subtask}"
+
+ Format your output in markdown as:
+ - [Name of Source](URL)
+ - [Name of Source](URL)
+ - [Name of Source](URL)
+ """
+     rsp = client.chat.completions.create(
+         model="gpt-4o",
+         messages=[{"role": "user", "content": prompt}],
+         temperature=0.3
+     )
+     return rsp.choices[0].message.content.strip()
+
+
+
+ def fallback_query_router(subtask: str, driver) -> pd.DataFrame:
+     text = subtask.lower()
+
+     with driver.session() as session:
+
+         # --- 1. Species: "where … observed/found …" ---
+         if "where" in text and ("observed" in text or "found" in text):
+             query = """
+             MATCH (o:Observation)-[:OBSERVED_ORGANISM]->(s:Species)
+             RETURN s.name AS species, o.site_name AS location, o.date AS date
+             ORDER BY o.date DESC
+             """
+
+         # --- 2. Before / after a given year ---
+         elif "before" in text or "after" in text:
+             # Non-capturing group so findall returns the full four-digit year
+             years = re.findall(r'\b(?:19|20)\d{2}\b', text)
+             if years:
+                 op = "<" if "before" in text else ">"
+                 query = f"""
+                 MATCH (o:Observation)-[:OBSERVED_ORGANISM]->(s:Species)
+                 WHERE o.date {op} date('{years[0]}-01-01')
+                 RETURN s.name AS species, o.site_name AS location, o.date AS date
+                 ORDER BY o.date DESC
+                 """
+             else:
+                 query = "RETURN 1"
+
+         # --- 3. Hurricane-related ---
+         elif "hurricane" in text:
+             query = """
+             MATCH (o:Observation)-[:OBSERVED_AT]->(h:Hurricane),
+                   (o)-[:OBSERVED_ORGANISM]->(s:Species),
+                   (o)-[:OBSERVED_IN]->(site)-[:BELONGS_TO]->(c:County)-[:IN_STATE]->(st:State)
+             WHERE st.name = 'Florida'
+             RETURN h.name AS hurricane,
+                    s.name AS species,
+                    site.name AS site,
+                    c.name AS county,
+                    o.date AS date
+             ORDER BY o.date DESC
+             """
+
+         # --- 4. Predation / predator ---
+         elif "preys on" in text or "predator" in text:
+             query = """
+             MATCH (s1:Species)-[:preysOn]->(s2:Species)
+             RETURN s1.name AS predator, s2.name AS prey
+             """
+
+         # --- 5. Default fallback ---
+         else:
+             query = """
+             MATCH (o:Observation)
+             RETURN o.animal_name AS species, o.site_name AS location, o.date AS date
+             """
+
+         # --- Execute the query ---
+         result = session.run(query)
+         df = pd.DataFrame(result.data())
+
+     if df.empty:
+         st.info("🌐 I couldn't find relevant data in KN‑Wildlife. Let me check external sources for you...")
+         suggestions = external_resource_recommender(subtask)
+         st.markdown(suggestions)
+
+     return df
+
+
+ def save_dataset(df: pd.DataFrame, filename: str) -> str:
+     if len(df) < 10:
+         st.warning(f"❌ Dataset too small to save: only {len(df)} rows.")
+         return ""
+     os.makedirs("saved_datasets", exist_ok=True)
+     path = f"saved_datasets/{filename}.csv"
+     if os.path.exists(path):
+         old_hash = hashlib.md5(open(path, 'rb').read()).hexdigest()
+         new_hash = hashlib.md5(df.to_csv(index=False).encode()).hexdigest()
+         if old_hash == new_hash:
+             st.info(f"ℹ️ Dataset already saved (unchanged): {filename}.csv")
+             return path
+     df.to_csv(path, index=False)
+     st.info(f"✅ Dataset saved: {filename}.csv")
+     return path
+ # ===================== CHART SUGGESTION (MODIFIED MAP SECTION) =====================
+
+ def suggest_charts_with_gpt(df: pd.DataFrame) -> str:
+     """Generate Streamlit chart code for automatic visualisation."""
+     try:
+         # st.write("🟢 COLS‑DEBUG:", list(df.columns))
+
+         # Ensure dates are parsed
+         if "date" in df.columns:
+             df["date"] = df["date"].apply(lambda x: x[0] if isinstance(x, (list, tuple)) and len(x) == 1 else x)
+             df["date"] = pd.to_datetime(df["date"], errors="coerce")
+
+         if "animal_name" in df.columns and "species" not in df.columns:
+             df["species"] = df["animal_name"]
+
+         df.rename(columns={"latitudes": "latitude", "longitudes": "longitude"}, inplace=True)
+
+         chart_code = """
+ # --- Species Bar Chart ---
+ if 'species' in df.columns:
+     st.markdown('📊 Count of Observations by Species')
+     try:
+         species_counts = df['species'].astype(str).value_counts()
+         st.bar_chart(species_counts)
+     except Exception as e:
+         st.warning(f'⚠️ Could not render species chart: {e}')
+
+ # --- Timeline Line Chart ---
+ if 'date' in df.columns:
+     st.markdown('📈 Observations Over Time')
+     try:
+         timeline = df['date'].dropna().value_counts().sort_index()
+         st.line_chart(timeline)
+     except Exception as e:
+         st.warning(f'⚠️ Could not render date chart: {e}')
+
+ # --- Map Visualisation (highlight all points) ---
+ if 'latitude' in df.columns and 'longitude' in df.columns:
+     st.markdown('🗺️ Observation Locations on Map')
+     try:
+         coords = (
+             df[['latitude', 'longitude']]
+             .apply(pd.to_numeric, errors='coerce')
+             .dropna()
+             .rename(columns={'latitude': 'lat', 'longitude': 'lon'})
+         )
+         coords = coords[
+             (coords['lat'].between(-90, 90)) &
+             (coords['lon'].between(-180, 180))
+         ]
+         if len(coords) == 0:
+             st.warning('⚠️ No valid coordinates to plot on the map.')
+         else:
+             # ---------- ① View state ----------
+             try:
+                 vs_tmp = pdk.data_utils.compute_view(coords[['lon', 'lat']])
+                 view_state = (
+                     pdk.ViewState(**vs_tmp, pitch=0, bearing=0)
+                     if isinstance(vs_tmp, dict) else vs_tmp
+                 )
+                 view_state.pitch = 0
+                 view_state.bearing = 0
+             except Exception:
+                 view_state = pdk.ViewState(
+                     latitude=coords['lat'].mean(),
+                     longitude=coords['lon'].mean(),
+                     zoom=5,
+                     pitch=0,
+                     bearing=0,
+                 )
+
+             # ---------- ② Highlight layer ----------
+             layer = pdk.Layer(
+                 'ScatterplotLayer',
+                 data=coords,
+                 get_position='[lon, lat]',
+                 get_radius=50000,
+                 get_fill_color=[0, 255, 0, 200],
+                 get_line_color=[255, 255, 255],
+                 line_width_units='pixels',
+                 get_line_width=2,
+                 pickable=True,
+                 auto_highlight=True,
+             )
+
+             # ---------- ③ Assemble the Deck ----------
+             deck = pdk.Deck(
+                 layers=[layer],
+                 initial_view_state=view_state,
+                 map_style='mapbox://styles/mapbox/light-v11',
+                 tooltip={'html': '<b>Lat:</b> {lat}<br/><b>Lon:</b> {lon}'},
+             )
+             st.pydeck_chart(deck)
+     except Exception as e:
+         st.warning(f'⚠️ Could not render map: {e}')
+ """
+         return textwrap.dedent(chart_code)
+     except Exception as outer_error:
+         return f"st.warning('❌ Chart generation failed: {outer_error}')"
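+ # The string returned above is later run via exec() with "st", "pd", "df", and "pdk"
+ # provided as globals (see the visualization stage below), so it must stay self-contained.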
+
+
+
+
+ # ========= UI layout and connection ==========
+ if "chat_history" not in st.session_state:
+     st.session_state.chat_history = []
+
+ # st.set_page_config(
+ #     page_title="🔬 Explainable Multi-Agent BioData Constructor",
+ #     layout="centered",
+ #     initial_sidebar_state="collapsed"
+ # )
+
+ # ——— Customize the main container's maximum width ———
+ st.markdown(
+     """
+     <style>
+     /* Body text */
+     html, body, .block-container, .markdown-text-container {
+         font-size: 19px !important;   /* ← adjust this number */
+         line-height: 1.6 !important;
+     }
+     /* Widen the default narrow max-width (~700px) to 1600px; adjust as needed */
+     .block-container {
+         max-width: 1600px;
+     }
+     </style>
+     """,
+     unsafe_allow_html=True
+ )
+
+ st.title("🔬 EcoData collector")
+
+
+ st.success("""
+ 👋 Hi there! I’m **Lily**, your research assistant bot 🤖. I’m here to help you explore data sources related to your **complex research question**. Let’s work together to find the information you need!
+
+ 💡 You can start by entering a research question like:
+
+ - *In Florida, how do hurricanes affect the distribution of snakes?*
+ - *How does precipitation impact salmon abundance in freshwater ecosystems?*
+ - *How do climate change and urbanization jointly affect bird migration and diversity in Florida?*
+ """)
+
+ if driver:
+     st.success("🟢 Connected to **KN-Wildlife** — a Neo4j-powered biodiversity graph focused on Florida’s species and ecosystems. I’ll start by checking what relevant data we already have in KN-Wildlife to support your research.")
+
+ else:
+     st.error("🔴 Failed to connect to KN-Wildlife! Please fix the connection first.")
+     st.stop()
+
+ question = st.text_area("Enter your research question:", "")
+
+ # Initialize session-state variables
+ if "start_clicked" not in st.session_state:
+     st.session_state.start_clicked = False
+ if "subtask_plan" not in st.session_state:
+     st.session_state.subtask_plan = ""
+ if "ready_to_continue" not in st.session_state:
+     st.session_state.ready_to_continue = False
+ if "stop_requested" not in st.session_state:
+     st.session_state.stop_requested = False
+ if "visualization_ready" not in st.session_state:
+     st.session_state.visualization_ready = False
+ if "do_visualize" not in st.session_state:
+     st.session_state.do_visualize = False
+ if "all_dataframes" not in st.session_state:
+     st.session_state.all_dataframes = []
+ if "retrieval_done" not in st.session_state:
+     st.session_state.retrieval_done = False
+
+ # Clicking the button triggers subtask decomposition
+ if st.button("Let’s start") and question.strip():
+     st.session_state.start_clicked = True
+     st.session_state.subtask_plan = planner_agent(question)
+     st.session_state.ready_to_continue = False
+     st.session_state.stop_requested = False
+     st.session_state.visualization_ready = False
+     st.session_state.do_visualize = False
+     st.session_state.all_dataframes = []
+     st.session_state.retrieval_done = False
+
+ # Stage 1: show the subtasks
+ if st.session_state.start_clicked:
+     # st.success("🧠 Now, I’ll break down your research question into several focused subtasks.")
+     st.success("🧠 I’ve identified the distinct datasets you’ll need for this research question.")
+     with st.expander("🔹 Curious how I split your question? Click to see!", expanded=True):
+         st.write(st.session_state.subtask_plan)
+
+     st.success("📌 I’m ready to roll up my sleeves — shall I start finding datasets for each subtask? 🕒 This step might take a little while, so thanks for your patience!")
+
+     col1, col2 = st.columns([1, 1])
+     with col1:
+         if st.button("✅ Yes, go ahead", key="confirm_button"):
+             st.session_state.ready_to_continue = True
+             st.session_state.stop_requested = False
+     with col2:
+         if st.button("⛔ No, stop here", key="stop_button"):
+             st.session_state.ready_to_continue = False
+             st.session_state.stop_requested = True
+
+
+ # ---------- Stage 2: data retrieval & rendering ----------
+ if st.session_state.ready_to_continue:
+
+     # ① Determine which prefix the planner output uses
+     # (assume it is either "Subtask" or "Dataset Need")
+     if "Dataset Need" in st.session_state.subtask_plan:
+         prefix = "Dataset Need"
+     else:
+         prefix = "Subtask"
+
+     # ② Build the pattern with an f-string (rf = raw, formatted)
+     pattern = rf"{prefix} \d+:.*?(?={prefix} \d+:|$)"
+     subtasks = re.findall(pattern,
+                           st.session_state.subtask_plan,
+                           flags=re.DOTALL)
+
+     # Warn if the planner produced no blocks at all
+     if not subtasks:
+         st.warning("⚠️ No dataset blocks detected in planner output.")
+         st.stop()
+
+     # Retrieval is performed only once
+     if not st.session_state.retrieval_done:                        # ★
+         progress_bar = st.progress(0)
+         total = len(subtasks)
+         saved_hashes = set()
+         st.session_state.all_dataframes = []
+
+
+     for idx, subtask in enumerate(subtasks):
+         # with st.expander(f"🔹 Retrieving data for subtask {idx+1}:", expanded=True):
+         with st.expander(f"🔹 Retrieving data for dataset need {idx+1}:", expanded=True):
+             cleaned_subtask = "\n".join(subtask.strip().split("\n")[1:])
+             st.markdown(cleaned_subtask)
+
+             # ---------- First run: perform the actual retrieval ----------
+             if not st.session_state.retrieval_done:                # ★
+                 df = intelligent_retriever_agent(subtask, saved_hashes)
+
+                 if not df.empty:
+                     df_hash = hashlib.md5(df.to_csv(index=False).encode()).hexdigest()
+                     if df_hash in saved_hashes:
+                         st.warning("⚠️ This dataset has already been saved — skipping duplicate.")
+                     elif len(df) < 10:
+                         st.warning(f"❌ This dataset is too small — just {len(df)} rows. Skipping save.")
+                     else:
+                         saved_hashes.add(df_hash)
+                         df = flatten_props(df)
+                         df = standardize_latlon(df)
+                         summary = evaluate_dataset_with_gpt(subtask, df)
+                         st.session_state.all_dataframes.append({
+                             "hash": df_hash,
+                             "df": df,
+                             "summary": summary
+                         })
+                         st.dataframe(df.head(50))
+                         save_path = save_dataset(df, f"subtask_{idx+1}")
+                         if save_path:
+                             # summary = evaluate_dataset_with_gpt(subtask, df)
+                             st.markdown("**📝 Dataset Introduction:**")
+                             st.write(summary)
+                 if 'progress_bar' in locals():
+                     progress_bar.progress((idx + 1) / total)
+
+             # ---------- Subsequent reruns: display only ----------
+             else:                                                  # ★
+                 if idx < len(st.session_state.all_dataframes):
+                     # _hash, df = st.session_state.all_dataframes[idx]
+                     # df = standardize_latlon(df)
+                     # st.dataframe(df.head(50))
+                     entry = st.session_state.all_dataframes[idx]   # new
+                     df = standardize_latlon(entry["df"])
+                     st.dataframe(df.head(50))
+                     st.write(entry.get("summary", ""))
+
+     # Mark retrieval as finished
+     if not st.session_state.retrieval_done:                        # ★
+         st.session_state.retrieval_done = True
+         st.session_state.visualization_ready = bool(st.session_state.all_dataframes)
+
+
+
+     if st.session_state.all_dataframes:
+         st.session_state.visualization_ready = True
+     else:
+         st.success("🎉 All subtasks completed and datasets generated! 💡 Feel free to ask me more questions anytime!")
+         # st.success("🎉 All subtasks completed and datasets generated!")
+         # st.success("💡 Feel free to ask Lily more questions anytime!")
+
+ # Stage 3: ask whether to visualize
+ if st.session_state.visualization_ready and not st.session_state.do_visualize:
+     st.success("📊 All set! I’ve gathered the datasets. Ready to visualize them?")
+
+     col1, col2 = st.columns([1, 1])
+     with col1:
+         if st.button("✅ Yes, go ahead", key="viz_confirm"):
+             st.session_state.do_visualize = True
+     with col2:
+         if st.button("⛔ No, stop here", key="viz_stop"):
+             st.session_state.visualization_ready = False
+             st.success("🎉 All subtasks completed and datasets generated! 💡 Feel free to ask me more questions anytime!")
+             # st.success("🎉 All subtasks completed and datasets generated!")
+             # st.success("💡 Feel free to ask Lily more questions anytime!")
+
+ # Stage 3: data visualization
+ if st.session_state.do_visualize:
+     for i, entry in enumerate(st.session_state.all_dataframes):
+         df = entry["df"]
+         summary = entry.get("summary", "")
+         if len(df) < 10:
+             continue
+         with st.expander(f"**🔹 Dataset {i + 1} Visualization**", expanded=True):
+             st.markdown(f"Dataset {i + 1} Preview")
+             st.dataframe(df.head(10))
+             chart_code = suggest_charts_with_gpt(df)
+             if chart_code:
+                 # st.markdown("🧠 The visualization code:")
+                 # st.code(chart_code, language="python")
+                 try:
+                     exec(chart_code, {"st": st, "pd": pd, "df": df, "pdk": pdk})
+                 except Exception as e:
+                     st.error(f"❌ Error running chart code: {e}")
+
+     st.success("🎉 All subtasks completed and datasets generated! 💡 Feel free to ask me more questions anytime!")
+     # st.success("💡 Feel free to ask me more questions anytime!")
+
+ if st.session_state.stop_requested:
+     st.info("👍 No problem! You can review the subtasks above or revise your question.")
+
+
+
+ # —— ChatGPT-style chat panel in the sidebar ——
+ with st.sidebar.expander("💬 Chat with Lily", expanded=True):
+     # Chat input box
+     user_msg = st.chat_input("Type your question here…", key="sidebar_chat_input")
+     if user_msg:
+         # Assemble the current on-page context
+         context_parts = []
+         if st.session_state.subtask_plan:
+             context_parts.append("Subtasks:\n" + st.session_state.subtask_plan)
+         for entry in st.session_state.all_dataframes:
+             context_parts.append("Data summary:\n" + entry["summary"])
+         page_context = "\n\n".join(context_parts)
+
+         # Call the GPT helper
+         with st.spinner("Lily is thinking…"):
+             assistant_msg = gpt_chat(
+                 sys_msg=f"You are Lily, a research assistant. Here’s what’s on screen:\n\n{page_context}",
+                 user_msg=user_msg
+             )
+
+         # Save the exchange
+         st.session_state.chat_history.append({"role": "user", "content": user_msg})
+         st.session_state.chat_history.append({"role": "assistant", "content": assistant_msg})
+
+     # Render the chat history
+     for msg in st.session_state.chat_history:
+         if msg["role"] == "user":
+             st.chat_message("user").write(msg["content"])
+         else:
+             st.chat_message("assistant").write(msg["content"])
requirements.txt CHANGED
@@ -1,3 +1,11 @@
- altair
+ streamlit
+ openai
  pandas
- streamlit
+ numpy
+ torch
+ scikit-learn
+ faiss-cpu
+ pydeck
+ transformers==4.35.2
+ neo4j
+ requests
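For a local run, the updated dependencies can be installed in the usual way, e.g.:

    pip install -r requirements.txt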