import streamlit as st
import folium
from streamlit_folium import st_folium

st.set_page_config(
    page_title="🔬 Explainable Multi-Agent BioData Constructor",
    layout="centered",
    initial_sidebar_state="collapsed"
)

from neo4j import GraphDatabase
import openai
import pandas as pd
import os
import re
import hashlib
import json
import pydeck as pdk
import faiss
import numpy as np
from sklearn.preprocessing import normalize
from transformers import AutoTokenizer, AutoModel
import torch
import ast
import textwrap
import requests

NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
openai_api_key = os.getenv("openai_api_key")

os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
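# The credentials above are read from the environment. A minimal local setup might look
# like this (all values are placeholders, not real endpoints or keys):
#
#   export NEO4J_URI="neo4j+s://<your-instance>.databases.neo4j.io"
#   export NEO4J_USERNAME="neo4j"
#   export NEO4J_PASSWORD="<password>"
#   export openai_api_key="sk-<your-key>"
#
# Note: some transformers/huggingface_hub versions read TRANSFORMERS_CACHE at import time,
# so setting it here (after the import) may not take effect in every environment.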
def download_if_missing(url, local_path):
    if not os.path.exists(local_path):
        with open(local_path, "wb") as f:
            f.write(requests.get(url).content)


base_url = "https://github.com/Tianyu-yang-anna/EcoData-collector/releases/download/v1.0"
files = {
    "nodes.csv": "/tmp/nodes.csv",
    "nodes_embeddings.npy": "/tmp/nodes_embeddings.npy",
    "relationships.csv": "/tmp/relationships.csv",
    "relationships_embeddings.npy": "/tmp/relationships_embeddings.npy"
}

for fname, path in files.items():
    download_if_missing(f"{base_url}/{fname}", path)

@st.cache_resource(show_spinner=False)
def create_driver():
    try:
        driver = GraphDatabase.driver(
            NEO4J_URI,
            auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
        )
        with driver.session() as session:
            session.run("RETURN 1")
        return driver
    except Exception as e:
        st.error(f"🔴 Neo4j connection failed: {e}")
        return None


driver = create_driver()

openai_client = openai.OpenAI(api_key=openai_api_key)

def gpt_chat(sys_msg: str, user_msg: str, **kwargs):
    rsp = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "system", "content": sys_msg}, {"role": "user", "content": user_msg}],
        **kwargs
    )
    return rsp.choices[0].message.content.strip()

class SimpleEncoder:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained("/app/model")
        self.model = AutoModel.from_pretrained("/app/model").to(self.device)
        self.model.eval()

    def encode(self, texts, batch_size: int = 16):
        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            with torch.no_grad():
                inputs = self.tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(self.device)
                outputs = self.model(**inputs)
                batch_emb = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
            embeddings.append(batch_emb)
        return np.vstack(embeddings)


@st.cache_resource(show_spinner=False)
def get_encoder():
    return SimpleEncoder()
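# SimpleEncoder mean-pools the last hidden state over the token dimension, so each input
# text maps to one vector of size hidden_dim (the exact size depends on whichever
# checkpoint is baked into /app/model). A rough sketch of the expected shapes, assuming
# a 768-dimensional model:
#
#   enc = get_encoder()
#   vecs = enc.encode(["snake observations in Florida"])   # np.ndarray of shape (1, 768)
#
# The vectors are not normalized here; rag_csv_fallback() normalizes the query side
# before searching the inner-product FAISS indexes.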
csv_file_pairs = [
    ("/tmp/nodes.csv", "/tmp/nodes_embeddings.npy"),
    ("/tmp/relationships.csv", "/tmp/relationships_embeddings.npy"),
]

for csv_path, npy_path in csv_file_pairs:
    if not os.path.exists(npy_path):
        st.error(f"❌ Embedding file not found: {npy_path}")
        st.stop()

@st.cache_resource(show_spinner=False)
def load_embeddings_and_faiss_indexes(file_pairs):
    index_list, metadatas = [], []
    for csv_path, npy_path in file_pairs:
        try:
            df = pd.read_csv(csv_path).fillna("")
            emb = np.load(npy_path).astype("float32")
            index = faiss.IndexFlatIP(emb.shape[1])
            if faiss.get_num_gpus() > 0:
                res = faiss.StandardGpuResources()
                index = faiss.index_cpu_to_gpu(res, 0, index)
            index.add(emb)
            index_list.append(index)
            metadatas.append(df)
        except Exception as e:
            st.warning(f"⚠️ Failed to load {csv_path} / {npy_path}: {e}")
            index_list.append(None)
            metadatas.append(pd.DataFrame())
    return index_list, metadatas


csv_faiss_indexes, csv_metadatas = load_embeddings_and_faiss_indexes(csv_file_pairs)
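# IndexFlatIP scores by raw inner product, which equals cosine similarity only when the
# stored vectors are unit-normalized. The query vector is normalized in rag_csv_fallback(),
# so this setup assumes the precomputed *_embeddings.npy files were L2-normalized when they
# were built; if they were not, ranking falls back to an unnormalized dot product.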
def flatten_props(df: pd.DataFrame) -> pd.DataFrame:
    if "props" not in df.columns:
        return df
    try:
        props_df = df["props"].apply(ast.literal_eval).apply(pd.Series)
        out = pd.concat([df.drop(columns=["props"]), props_df], axis=1)
        return out
    except Exception as e:
        st.warning(f"⚠️ Failed to parse props column: {e}")
        return df


def unpack_singletons(df: pd.DataFrame) -> pd.DataFrame:
    for col in df.columns:
        if df[col].apply(lambda x: isinstance(x, (list, tuple)) and len(x) == 1).any():
            df[col] = df[col].apply(lambda x: x[0] if isinstance(x, (list, tuple)) and len(x) == 1 else x)
    return df

def standardize_latlon(df: pd.DataFrame) -> pd.DataFrame:
    """
    - Normalize column names to latitudes / longitudes.
    - If duplicate columns share a name, keep the first and drop the rest.
    - Keep longitudes where it is and move latitudes directly to its right.
    """
    col_map = {}
    for col in df.columns:
        low = col.lower()
        if "lat" in low and "lon" not in low:
            col_map[col] = "latitudes"
        elif ("lon" in low or "lng" in low):
            col_map[col] = "longitudes"
    df = df.rename(columns=col_map)

    while df.columns.duplicated().any():
        dup_col = df.columns[df.columns.duplicated()][0]
        first_idx = list(df.columns).index(dup_col)
        keep = [True] * len(df.columns)
        for i, c in enumerate(df.columns):
            if c == dup_col and i != first_idx:
                keep[i] = False
        df = df.loc[:, keep]

    for c in ("latitudes", "longitudes"):
        if c in df.columns and not isinstance(df[c], pd.Series):
            df[c] = df[c].iloc[:, 0]
        if c in df.columns:
            df[c] = df[c].apply(
                lambda x: x[0] if isinstance(x, (list, tuple)) and len(x) == 1 else x
            )
            df[c] = pd.to_numeric(df[c], errors="coerce")

    if {"longitudes", "latitudes"}.issubset(df.columns):
        cols = list(df.columns)
        lon_idx = cols.index("longitudes")
        lat_idx = cols.index("latitudes")
        if lat_idx != lon_idx + 1:
            cols.pop(lat_idx)
            cols.insert(lon_idx + 1, "latitudes")
            df = df[cols]

    return df
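# Illustrative (hypothetical) example of what standardize_latlon does:
#
#   columns in:  ["species", "Lat", "LONGITUDE", "lat_deg"]
#   columns out: ["species", "longitudes", "latitudes"]
#
# Both "Lat" and "lat_deg" map to "latitudes"; the duplicate is dropped and "latitudes"
# is moved directly to the right of "longitudes". Values that arrive as one-element lists
# (e.g. [27.3]) are unwrapped and coerced to numeric, with unparseable entries becoming NaN.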
@st.cache_data(show_spinner=False)
def rag_csv_fallback(subtask, top_k=2000):
    encoder = get_encoder()
    query_vec = encoder.encode([subtask])
    query_vec = normalize(query_vec, axis=1).astype("float32")
    if not np.any(query_vec):
        return pd.DataFrame()
    all_results = []
    for index, metadata in zip(csv_faiss_indexes, csv_metadatas):
        if index is None or metadata.empty:
            continue
        distances, indices = index.search(query_vec, top_k)
        retrieved = metadata.iloc[indices[0]].copy()
        all_results.append(retrieved)
    if all_results:
        return pd.concat(all_results).drop_duplicates().reset_index(drop=True)
    return pd.DataFrame()
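# Caveat: top_k=2000 can exceed the number of vectors in an index, in which case FAISS pads
# `indices` with -1 and pandas .iloc maps -1 to the last metadata row. A stricter variant
# would keep only the valid hits, for example:
#
#   valid = indices[0][indices[0] >= 0]
#   retrieved = metadata.iloc[valid].copy()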
def generate_cypher_with_gpt(subtask: str) -> str:
    prompt = f"""
You are an expert Cypher query generator for a Neo4j biodiversity database. The schema is as follows:

Node Types and Properties:
- Observation: animal_name, date, latitude, longitude
- Species: name, species_full_name
- Site: name
- County: name
- State: name
- Hurricane: name
- Policy: title, description
- ClimateEvent: event_type, date
- TemperatureReading: value, date, location
- Precipitation: amount, date, location

Relationship Types:
- OBSERVED_IN: (Observation)-[:OBSERVED_IN]->(Site)
- OBSERVED_ORGANISM: (Observation)-[:OBSERVED_ORGANISM]->(Species)
- BELONGS_TO: (Site)-[:BELONGS_TO]->(County)
- IN_COUNTY: (Observation)-[:IN_COUNTY]->(County)
- IN_STATE: (County)-[:IN_STATE]->(State)
- interactsWith: (Species)-[:interactsWith]->(Species)
- preysOn: (Species)-[:preysOn]->(Species)

Your task is to generate a **precise and efficient** Cypher query for the following subtask:
"{subtask}"

Guidelines:
- Do NOT return all nodes of a type (e.g., all Species) unless the subtask explicitly asks for it.
- If a location (county/state) is mentioned or implied, include location filtering using IN_COUNTY, IN_STATE, or BELONGS_TO.
- If the subtask implies a taxonomic or common name group (e.g., frog, snake, salmon), apply CONTAINS or STARTS WITH filters on Species.name or species_full_name, using toLower(...) for case-insensitive matching.
- If the subtask includes a time range, include date filtering.
- Prefer using DISTINCT to avoid redundant results.
- Only return fields that are clearly needed to fulfill the subtask.

Return your response strictly as a **JSON object** with the following fields:
- "intent": a short description of what the query does
- "cypher_query": the Cypher query
- "fields": a list of returned field names (e.g., ["species", "county", "date"])

Do not include any explanation or commentary—only return the JSON object.
"""

    client = openai.OpenAI(api_key=os.getenv("openai_api_key"))
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        temperature=0
    )
    content = response.choices[0].message.content.strip()
    # Strip a leading ```json / ```python fence and a trailing ``` if the model added them.
    content = re.sub(r"^```(json|python)?", "", content, flags=re.IGNORECASE).strip()
    content = re.sub(r"```$", "", content).strip()

    try:
        cypher_json = json.loads(content)
        return cypher_json["cypher_query"]
    except Exception:
        return ""
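# The model sometimes wraps its JSON answer in a fenced code block; the two re.sub calls in
# generate_cypher_with_gpt strip a leading ```json / ```python marker and a trailing ``` so
# that json.loads can parse the payload. A hypothetical raw response before cleaning:
#
#   ```json
#   {"intent": "Snake observations in Florida", "cypher_query": "MATCH ...", "fields": ["species", "county", "date"]}
#   ```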
def intelligent_retriever_agent(subtask, saved_hashes=None):
    if saved_hashes is None:
        saved_hashes = set()
    st.success("🔍 Attempting to retrieve data from the Ecodata knowledge graph…")
    cypher_query = generate_cypher_with_gpt(subtask)
    cypher_df = pd.DataFrame()
    if cypher_query.strip():
        st.code(cypher_query, language="cypher")
        try:
            query = re.sub(r"(?i)LIMIT\s+\d+\s*$", "", cypher_query)
            with driver.session() as session:
                result = session.run(query)
                cypher_df = pd.DataFrame(result.data())
        except Exception as e:
            st.error(f"🚨 Cypher execution error: {e}")
            st.code(query, language="cypher")

    fallback_needed = False
    if cypher_df.empty:
        fallback_needed = True
    else:
        df_hash = hashlib.md5(cypher_df.to_csv(index=False).encode()).hexdigest()
        st.write(f"ℹ️ Cypher rows: {len(cypher_df)} | duplicate?: {df_hash in saved_hashes}")
        if df_hash in saved_hashes or len(cypher_df) < 10:
            fallback_needed = True

    if fallback_needed:
        csv_df = rag_csv_fallback(subtask)
        if not csv_df.empty:
            csv_df = flatten_props(csv_df)
            csv_df = unpack_singletons(csv_df)
            csv_df = standardize_latlon(csv_df)
            return csv_df
        st.warning("❌ CSV fallback also returned nothing.")
        return pd.DataFrame()

    st.success("✅ Cypher query successful. Using Cypher result.")
    cypher_df = flatten_props(cypher_df)
    cypher_df = unpack_singletons(cypher_df)
    cypher_df = standardize_latlon(cypher_df)
    if "species" not in cypher_df.columns and "animal_name" in cypher_df.columns:
        cypher_df["species"] = cypher_df["animal_name"]
    if "date" in cypher_df.columns:
        cypher_df["date"] = pd.to_datetime(cypher_df["date"], errors="coerce")
    cypher_df.rename(columns={"latitudes": "latitude", "longitudes": "longitude", "lat": "latitude", "lon": "longitude"}, inplace=True)
    for col in ("latitude", "longitude"):
        if col in cypher_df.columns:
            cypher_df[col] = pd.to_numeric(cypher_df[col], errors="coerce")
    return cypher_df
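# Retrieval strategy in intelligent_retriever_agent:
#   1. Ask GPT-4o for a Cypher query and run it against Neo4j (any trailing LIMIT is stripped).
#   2. Fall back to the FAISS/CSV semantic search when the Cypher result is empty, has fewer
#      than 10 rows, or hashes to a dataset that was already saved.
# Both paths are cleaned with flatten_props / unpack_singletons / standardize_latlon; the
# Cypher path additionally renames the coordinate columns to latitude / longitude.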
def planner_agent(question: str) -> str:
    prompt = f"""
You are a **research‑data planning assistant**.

------------------------ 📝 TASK ------------------------
Your job is to list the **separate data sets** a researcher must collect
to answer the research question below.

*Each data set* should be focused on one clearly defined entity or
phenomenon (e.g. "Tracks of hurricanes affecting Florida since 1950",
"Geo‑tagged snake observations in Florida 2000‑present").

-------------------- 📋 OUTPUT FORMAT --------------------
Write 1–3 blocks. For **each** block use *both* lines exactly:

Dataset Need X: <Concise title, ≤ 10 words>
Description: <Why this data matters—1 short sentence>

⚠️ Do NOT add extra lines or markdown.
⚠️ Keep variable names short; no code blocks; no quotes.

-------------------- 🔍 RESEARCH QUESTION --------------------
{question}
"""
    rsp = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "You are an expert research planner."},
            {"role": "user", "content": prompt}
        ],
        temperature=0.2
    )
    return rsp.choices[0].message.content.strip()
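# The planner is expected to emit plain-text blocks that the regex in the main flow can split
# on. A hypothetical output for a hurricane/snake question:
#
#   Dataset Need 1: Snake observations in Florida
#   Description: Needed to map where snakes occur before and after storms.
#   Dataset Need 2: Hurricane tracks affecting Florida
#   Description: Needed to align observations with storm timing and paths.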
def evaluate_dataset_with_gpt(subtask: str, df: pd.DataFrame, client=openai_client) -> str:
    max_columns = 50
    selected_cols = df.columns[:max_columns]
    column_info = {col: str(df[col].dtype) for col in selected_cols}
    sample_rows = df.head(3)[selected_cols].to_dict(orient="records")

    prompt = f"""
You are a data‑validation assistant. Decide whether the dataset below is useful for the research subtask.

===== TASK =====
Subtask: "{subtask}"

===== DATASET PREVIEW =====
Schema (first {len(selected_cols)} columns):
{json.dumps(column_info, indent=10)}
Sample rows (up to 3):
{json.dumps(sample_rows, indent=10)}

===== OUTPUT INSTRUCTIONS (follow strictly) =====
Case A – Relevant:
  • Write exactly two sentences, each no more than 30 words.
  • Summarize what the dataset contains and why it helps the subtask.
  • Do not mention column names or list individual rows.

Case B – Not relevant:
  • Write one or two sentences, each no more than 30 words, **describing only what the dataset contains**.
  • Do **not** mention the subtask, relevance, suitability, limitations, or missing information (avoid phrases like “not related,” “does not focus,” “irrelevant,” etc.).
  • After the sentences, output the header **Additionally, here are some external resources you might find helpful:** on a new line. Format your output in markdown as:
    - [Name of Source](URL)
  • Then list 2–3 bullet points, each on its own line, starting with “- ” followed immediately by a URL likely to contain the needed data.
  • No additional commentary.

General rules:
Plain text only — no code fences. Markdown link syntax (`[text](url)`) is allowed.
"""

    rsp = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.3,
    )
    return rsp.choices[0].message.content.strip()
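# evaluate_dataset_with_gpt returns free text that the UI prints verbatim under
# "Dataset Introduction". In the "not relevant" case the model is asked to append markdown
# links, which st.write renders as clickable bullets.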
def external_resource_recommender(subtask: str, client=openai_client) -> str:
    prompt = f"""
You are a helpful research assistant. Your task is to recommend **three reliable, publicly accessible online datasets or data repositories** that can assist with the following scientific subtask:

{subtask}

Only include sources that are:
- Trusted (e.g., government, academic, or well-established platforms)
- Relevant to the topic
- Accessible without login when possible

Format your answer strictly in markdown:
- [Name of Source](URL)
- [Name of Source](URL)
- [Name of Source](URL)

Do not include any explanations or extra text—only the list.
"""
    rsp = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.3
    )
    return rsp.choices[0].message.content.strip()

def fallback_query_router(subtask: str, driver) -> pd.DataFrame:
    text = subtask.lower()

    with driver.session() as session:

        if "where" in text and ("observed" in text or "found" in text):
            query = """
            MATCH (o:Observation)-[:OBSERVED_ORGANISM]->(s:Species)
            RETURN s.name AS species, o.site_name AS location, o.date AS date
            ORDER BY o.date DESC
            """

        elif "before" in text or "after" in text:
            # Match a full four-digit year (1900–2099).
            years = re.findall(r'\b((?:19|20)\d{2})\b', text)
            if years:
                op = "<" if "before" in text else ">"
                query = f"""
                MATCH (o:Observation)-[:OBSERVED_ORGANISM]->(s:Species)
                WHERE o.date {op} date('{years[0]}-01-01')
                RETURN s.name AS species, o.site_name AS location, o.date AS date
                ORDER BY o.date DESC
                """
            else:
                query = "RETURN 1"

        elif "hurricane" in text:
            query = """
            MATCH (o:Observation)-[:OBSERVED_AT]->(h:Hurricane),
                  (o)-[:OBSERVED_ORGANISM]->(s:Species),
                  (o)-[:OBSERVED_IN]->(site)-[:BELONGS_TO]->(c:County)-[:IN_STATE]->(st:State)
            WHERE st.name = 'Florida'
            RETURN h.name AS hurricane,
                   s.name AS species,
                   site.name AS site,
                   c.name AS county,
                   o.date AS date
            ORDER BY o.date DESC
            """

        elif "preys on" in text or "predator" in text:
            query = """
            MATCH (s1:Species)-[:preysOn]->(s2:Species)
            RETURN s1.name AS predator, s2.name AS prey
            """

        else:
            query = """
            MATCH (o:Observation)
            RETURN o.animal_name AS species, o.site_name AS location, o.date AS date
            """

        result = session.run(query)
        df = pd.DataFrame(result.data())

    if df.empty:
        st.info("🌐 I couldn't find relevant data in KN‑Wildlife. Let me check external sources for you...")
        suggestions = external_resource_recommender(subtask)
        st.markdown(suggestions)

    return df
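# fallback_query_router is a simple keyword-routed alternative to the GPT-generated Cypher
# path (it keys off phrases such as "before"/"after", "hurricane", and "preys on"). It does
# not appear to be called from the main Streamlit flow in this file, so it is presumably kept
# for manual experiments or an older entry point.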
def save_dataset(df: pd.DataFrame, filename: str) -> str:
    if len(df) < 10:
        st.warning(f"❌ Dataset too small to save: only {len(df)} rows.")
        return ""
    save_dir = "/tmp/saved_datasets"
    os.makedirs(save_dir, exist_ok=True)
    path = f"{save_dir}/{filename}.csv"
    if os.path.exists(path):
        old_hash = hashlib.md5(open(path, 'rb').read()).hexdigest()
        new_hash = hashlib.md5(df.to_csv(index=False).encode()).hexdigest()
        if old_hash == new_hash:
            st.info(f"ℹ️ Dataset already saved (unchanged): {filename}.csv")
            return path
    df.to_csv(path, index=False)
    st.info(f"✅ Dataset saved: {filename}.csv")

    return path
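# save_dataset refuses datasets under 10 rows and skips rewriting a file whose contents hash
# to what is already on disk, so Streamlit reruns of the same subtask do not rewrite /tmp
# unnecessarily.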
def suggest_charts_with_gpt(df: pd.DataFrame) -> str:
    """Generate Streamlit chart code for automatic visualisation."""
    try:
        if "date" in df.columns:
            df["date"] = df["date"].apply(lambda x: x[0] if isinstance(x, (list, tuple)) and len(x) == 1 else x)
            df["date"] = pd.to_datetime(df["date"], errors="coerce")

        if "animal_name" in df.columns and "species" not in df.columns:
            df["species"] = df["animal_name"]

        df.rename(columns={"latitudes": "latitude", "longitudes": "longitude"}, inplace=True)

        chart_code = """
# --- Species Bar Chart ---
if 'species' in df.columns:
    st.markdown('📊 Count of Observations by Species')
    try:
        species_counts = df['species'].astype(str).value_counts()
        st.bar_chart(species_counts)
    except Exception as e:
        st.warning(f'⚠️ Could not render species chart: {e}')

# --- Timeline Line Chart ---
if 'date' in df.columns:
    st.markdown('📈 Observations Over Time')
    try:
        timeline = df['date'].dropna().value_counts().sort_index()
        st.line_chart(timeline)
    except Exception as e:
        st.warning(f'⚠️ Could not render date chart: {e}')

# --- Map Visualisation (highlight all points) ---
if 'latitude' in df.columns and 'longitude' in df.columns:
    st.markdown('🗺️ Observation Locations on Map')
    try:
        coords = df[['latitude', 'longitude']].dropna()
        coords = coords[(coords['latitude'].between(-90, 90)) & (coords['longitude'].between(-180, 180))]

        if len(coords) == 0:
            raise Exception('⚠️ No valid coordinates to plot on the map.')
        else:
            # Centre the map on the mean coordinate
            center = [coords['latitude'].mean(), coords['longitude'].mean()]
            m = folium.Map(location=center, zoom_start=5)

            # Add one marker per observation
            for _, row in coords.iterrows():
                folium.CircleMarker(
                    location=[row['latitude'], row['longitude']],
                    radius=5,
                    color='green',
                    fill=True,
                    fill_color='green',
                    fill_opacity=0.7,
                ).add_to(m)

            st_folium(m, width=700, height=500)
    except Exception as e:
        st.warning(f'⚠️ Could not render map: {e}')
"""
        return textwrap.dedent(chart_code)
    except Exception as outer_error:
        return f"st.warning('❌ Chart generation failed: {outer_error}')"
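# The string returned by suggest_charts_with_gpt is executed later with exec() in the
# visualization step, with only {st, pd, df, pdk, folium, st_folium} passed as globals, so
# the generated snippet must not reference anything outside that whitelist.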
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []

st.markdown(
    """
    <style>
    /* Body text */
    html, body, .block-container, .markdown-text-container {
        font-size: 19px !important;  /* adjust this value to change the base font size */
        line-height: 1.6 !important;
    }
    /* Widen the default narrow content column (about 700px); adjust as needed */
    .block-container {
        max-width: 1600px;
    }
    </style>
    """,
    unsafe_allow_html=True
)
st.title("🐾 Quest2DataAgent_EcoData")

st.success("""
👋 Hi there! I’m **Lily**, your research assistant bot 🤖. I’m here to help you explore data sources related to your **complex research question**. Let’s work together to find the information you need!

💡 You can start by entering a research question like:

- *In Florida, how do hurricanes affect the distribution of snakes?*
- *How does precipitation impact salmon abundance in freshwater ecosystems?*
- *How do climate change and urbanization jointly affect bird migration and diversity in Florida?*
""")
if driver:
    st.success("🟢 Connected to **Ecodata** — a Neo4j-powered biodiversity graph focused on species and ecosystems. I’ll start by checking what relevant data we already have in Ecodata to support your research.")
else:
    st.error("🔴 Failed to connect to Ecodata! Please fix the connection first.")
    st.stop()
question = st.text_area("Enter your research question:", "")

if "start_clicked" not in st.session_state:
    st.session_state.start_clicked = False
if "subtask_plan" not in st.session_state:
    st.session_state.subtask_plan = ""
if "ready_to_continue" not in st.session_state:
    st.session_state.ready_to_continue = False
if "stop_requested" not in st.session_state:
    st.session_state.stop_requested = False
if "visualization_ready" not in st.session_state:
    st.session_state.visualization_ready = False
if "do_visualize" not in st.session_state:
    st.session_state.do_visualize = False
if "all_dataframes" not in st.session_state:
    st.session_state.all_dataframes = []
if "retrieval_done" not in st.session_state:
    st.session_state.retrieval_done = False
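# All workflow state lives in st.session_state because every button click triggers a full
# Streamlit rerun of this script; the flags above let each rerun resume where the user left
# off instead of restarting the pipeline.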
if st.button("Let’s start") and question.strip():
    st.session_state.start_clicked = True
    st.session_state.subtask_plan = planner_agent(question)
    st.session_state.ready_to_continue = False
    st.session_state.stop_requested = False
    st.session_state.visualization_ready = False
    st.session_state.do_visualize = False
    st.session_state.all_dataframes = []
    st.session_state.retrieval_done = False
if st.session_state.start_clicked:

    st.success("🧠 I’ve identified the distinct datasets you’ll need for this research question.")
    with st.expander("🔹 Curious how I split your question? Click to see!", expanded=True):
        st.write(st.session_state.subtask_plan)

    st.success("📌 I’m ready to roll up my sleeves — shall I start finding datasets for each subtask? 🕒 This step might take a little while, so thanks for your patience!")

    col1, col2 = st.columns([1, 1])
    with col1:
        if st.button("✅ Yes, go ahead", key="confirm_button"):
            st.session_state.ready_to_continue = True
            st.session_state.stop_requested = False
    with col2:
        if st.button("⛔ No, stop here", key="stop_button"):
            st.session_state.ready_to_continue = False
            st.session_state.stop_requested = True
if st.session_state.ready_to_continue:

    if "Dataset Need" in st.session_state.subtask_plan:
        prefix = "Dataset Need"
    else:
        prefix = "Subtask"

    pattern = rf"{prefix} \d+:.*?(?={prefix} \d+:|$)"
    subtasks = re.findall(pattern,
                          st.session_state.subtask_plan,
                          flags=re.DOTALL)

    if not subtasks:
        st.warning("⚠️ No dataset blocks detected in planner output.")
        st.stop()
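    # The pattern splits the planner text into one block per "Dataset Need N:" (or
    # "Subtask N:") heading. For a hypothetical plan such as
    #   "Dataset Need 1: ...\nDescription: ...\nDataset Need 2: ..."
    # re.findall returns two strings, each starting at its own heading and running up to
    # the next one.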
    if not st.session_state.retrieval_done:
        progress_bar = st.progress(0)
        total = len(subtasks)
        saved_hashes = set()
        st.session_state.all_dataframes = []

    for idx, subtask in enumerate(subtasks):

        with st.expander(f"🔹 Retrieving data for dataset need {idx+1}:", expanded=True):
            cleaned_subtask = "\n".join(subtask.strip().split("\n")[1:])
            st.markdown(cleaned_subtask)

            if not st.session_state.retrieval_done:
                df = intelligent_retriever_agent(subtask, saved_hashes)

                if not df.empty:
                    df_hash = hashlib.md5(df.to_csv(index=False).encode()).hexdigest()
                    if df_hash in saved_hashes:
                        st.warning("⚠️ This dataset has already been saved — skipping duplicate.")
                    elif len(df) < 10:
                        st.warning(f"❌ This dataset is too small — just {len(df)} rows. Skipping save.")
                    else:
                        saved_hashes.add(df_hash)
                        df = flatten_props(df)
                        df = standardize_latlon(df)
                        summary = evaluate_dataset_with_gpt(subtask, df)
                        st.session_state.all_dataframes.append({
                            "hash": df_hash,
                            "df": df,
                            "summary": summary
                        })

                        st.dataframe(df.head(50))
                        save_path = save_dataset(df, f"subtask_{idx+1}")
                        if save_path:
                            st.markdown("**📝 Dataset Introduction:**")
                            st.write(summary)

                            with open(save_path, "rb") as f:
                                st.download_button(
                                    label="📥 Download dataset (CSV)",
                                    data=f,
                                    file_name=os.path.basename(save_path),
                                    mime="text/csv",
                                    key=f"download_init_{idx}"
                                )

                if 'progress_bar' in locals():
                    progress_bar.progress((idx + 1) / total)

            else:
                if idx < len(st.session_state.all_dataframes):
                    entry = st.session_state.all_dataframes[idx]
                    df = standardize_latlon(entry["df"])
                    st.dataframe(df.head(50))

                    st.markdown("**📝 Dataset Introduction:**")
                    st.write(entry.get("summary", ""))

                    tmp_path = f"/tmp/subtask_{idx+1}_display.csv"
                    df.to_csv(tmp_path, index=False)
                    with open(tmp_path, "rb") as f:
                        st.download_button(
                            label="📥 Download dataset (CSV)",
                            data=f,
                            file_name=os.path.basename(tmp_path),
                            mime="text/csv",
                            key=f"download_rerun_{idx}"
                        )
    if not st.session_state.retrieval_done:
        st.session_state.retrieval_done = True
        st.session_state.visualization_ready = bool(st.session_state.all_dataframes)

    if st.session_state.all_dataframes:
        st.session_state.visualization_ready = True
    else:
        st.success("🎉 All subtasks completed and datasets generated! 💡 Feel free to ask me more questions anytime!")
if st.session_state.visualization_ready and not st.session_state.do_visualize:
    st.success("📊 All set! I’ve gathered the datasets. Ready to visualize them?")

    col1, col2 = st.columns([1, 1])
    with col1:
        if st.button("✅ Yes, go ahead", key="viz_confirm"):
            st.session_state.do_visualize = True
    with col2:
        if st.button("⛔ No, stop here", key="viz_stop"):
            st.session_state.visualization_ready = False
            st.success("🎉 All subtasks completed and datasets generated! 💡 Feel free to ask me more questions anytime!")
if st.session_state.do_visualize:
    for i, entry in enumerate(st.session_state.all_dataframes):
        df = entry["df"]
        summary = entry.get("summary", "")
        if len(df) < 10:
            continue
        with st.expander(f"**🔹 Dataset {i + 1} Visualization**", expanded=True):
            st.markdown(f"Dataset {i + 1} Preview")
            st.dataframe(df.head(10))
            chart_code = suggest_charts_with_gpt(df)
            if chart_code:
                try:
                    exec(chart_code, {"st": st, "pd": pd, "df": df, "pdk": pdk, "folium": folium, "st_folium": st_folium})
                except Exception as e:
                    st.error(f"❌ Error running chart code: {e}")

    st.success("🎉 All subtasks completed and datasets generated! 💡 Feel free to ask me more questions anytime!")
if st.session_state.stop_requested:
    st.info("👍 No problem! You can review the subtasks above or revise your question.")
with st.sidebar.expander("💬 Chat with Lily", expanded=True):

    user_msg = st.chat_input("Type your question here…", key="sidebar_chat_input")
    if user_msg:
        context_parts = []
        if st.session_state.subtask_plan:
            context_parts.append("Subtasks:\n" + st.session_state.subtask_plan)
        for entry in st.session_state.all_dataframes:
            context_parts.append("Data summary:\n" + entry["summary"])
        page_context = "\n\n".join(context_parts)

        with st.spinner("Lily is thinking…"):
            assistant_msg = gpt_chat(
                sys_msg=f"You are Lily, a research assistant. Here’s what’s on screen:\n\n{page_context}",
                user_msg=user_msg
            )

        st.session_state.chat_history.append({"role": "user", "content": user_msg})
        st.session_state.chat_history.append({"role": "assistant", "content": assistant_msg})

    for msg in st.session_state.chat_history:
        if msg["role"] == "user":
            st.chat_message("user").write(msg["content"])
        else:
            st.chat_message("assistant").write(msg["content"])