import streamlit as st
import folium
from streamlit_folium import st_folium
st.set_page_config(
    page_title="🔬 Explainable Multi-Agent BioData Constructor",
    layout="centered",
    initial_sidebar_state="collapsed"
)
from neo4j import GraphDatabase
import openai
import pandas as pd
import os
import re
import hashlib
import json
import pydeck as pdk
import faiss
import numpy as np
from sklearn.preprocessing import normalize
from transformers import AutoTokenizer, AutoModel
import torch
import ast
import textwrap
import requests
# ============================== CONFIGURATION ==============================
NEO4J_URI = os.getenv("NEO4J_URI")
NEO4J_USERNAME = os.getenv("NEO4J_USERNAME")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")
openai_api_key = os.getenv("openai_api_key")

os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
# ============================== DOWNLOAD ==============================
def download_if_missing(url, local_path):
    """Download url to local_path if the file is not already present."""
    if not os.path.exists(local_path):
        resp = requests.get(url)
        resp.raise_for_status()  # fail loudly instead of silently saving an error page
        with open(local_path, "wb") as f:
            f.write(resp.content)

base_url = "https://github.com/Tianyu-yang-anna/EcoData-collector/releases/download/v1.0"
files = {
    "nodes.csv": "/tmp/nodes.csv",
    "nodes_embeddings.npy": "/tmp/nodes_embeddings.npy",
    "relationships.csv": "/tmp/relationships.csv",
    "relationships_embeddings.npy": "/tmp/relationships_embeddings.npy"
}

for fname, path in files.items():
    download_if_missing(f"{base_url}/{fname}", path)

# ============================== NEO4J DRIVER ==============================
@st.cache_resource(show_spinner=False)
def create_driver():
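    """Create a cached Neo4j driver and verify connectivity with a trivial RETURN 1 query; return None (and show an error) on failure."""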
    try:
        driver = GraphDatabase.driver(
            NEO4J_URI,
            auth=(NEO4J_USERNAME, NEO4J_PASSWORD)
        )
        with driver.session() as session:
            session.run("RETURN 1")
        return driver
    except Exception as e:
        st.error(f"🔴 Neo4j connection failed: {e}")
        return None

driver = create_driver()
# ============================== SIMPLE GPT HELPER ==============================
openai_client = openai.OpenAI(api_key=openai_api_key)

def gpt_chat(sys_msg: str, user_msg: str, **kwargs):
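    """Send a system + user message pair to GPT-4o and return the stripped reply text.

    Example (illustrative): gpt_chat("You are a helpful assistant.", "Summarize this dataset.", temperature=0.2)
    """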
    rsp = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "system", "content": sys_msg}, {"role": "user", "content": user_msg}],
        **kwargs
    )
    return rsp.choices[0].message.content.strip()

# ============================== EMBEDDING ENCODER ==============================
class SimpleEncoder:
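    """Lightweight text encoder: a locally stored transformer (/app/model) whose mean-pooled last hidden state serves as the embedding."""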
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained("/app/model")
        self.model = AutoModel.from_pretrained("/app/model").to(self.device)
        self.model.eval()

    def encode(self, texts, batch_size: int = 16):
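        """Encode texts in batches of `batch_size`; returns an array of shape (len(texts), hidden_size)."""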
        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            with torch.no_grad():
                inputs = self.tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(self.device)
                outputs = self.model(**inputs)
                batch_emb = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
                embeddings.append(batch_emb)
        return np.vstack(embeddings)


@st.cache_resource(show_spinner=False)
def get_encoder():
    return SimpleEncoder()

# ============================== FAISS INDEX LOADING ==============================
csv_file_pairs = [
    ("/tmp/nodes.csv", "/tmp/nodes_embeddings.npy"),
    ("/tmp/relationships.csv", "/tmp/relationships_embeddings.npy"),
]

for csv_path, npy_path in csv_file_pairs:
    if not os.path.exists(npy_path):
        st.error(f"❌ Embedding file not found: {npy_path}")
        st.stop()

@st.cache_resource(show_spinner=False)
def load_embeddings_and_faiss_indexes(file_pairs):
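    """Load each (CSV metadata, .npy embedding) pair and build an inner-product FAISS index (moved to GPU when available); a failed pair yields a None index and an empty dataframe."""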
    index_list, metadatas = [], []
    for csv_path, npy_path in file_pairs:
        try:
            df = pd.read_csv(csv_path).fillna("")
            emb = np.load(npy_path).astype("float32")
            index = faiss.IndexFlatIP(emb.shape[1])
            if faiss.get_num_gpus() > 0:
                res = faiss.StandardGpuResources()
                index = faiss.index_cpu_to_gpu(res, 0, index)
            index.add(emb)
            index_list.append(index)
            metadatas.append(df)
        except Exception as e:
            st.warning(f"⚠️ Failed to load {csv_path} / {npy_path}: {e}")
            index_list.append(None)
            metadatas.append(pd.DataFrame())
    return index_list, metadatas

csv_faiss_indexes, csv_metadatas = load_embeddings_and_faiss_indexes(csv_file_pairs)

# ============================== DATAFRAME UTILITIES ==============================

def flatten_props(df: pd.DataFrame) -> pd.DataFrame:
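    """Expand a stringified 'props' dict column into real columns; return the dataframe unchanged if the column is missing or cannot be parsed."""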
    if "props" not in df.columns:
        return df
    try:
        props_df = df["props"].apply(ast.literal_eval).apply(pd.Series)
        out = pd.concat([df.drop(columns=["props"]), props_df], axis=1)
        # st.write("✅ props flattened, new columns:", list(props_df.columns))
        return out
    except Exception as e:
        st.warning(f"⚠️ Failed to parse props column: {e}")
        return df

def unpack_singletons(df: pd.DataFrame) -> pd.DataFrame:
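    """Unwrap single-element lists/tuples in every column to their bare value."""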
    for col in df.columns:
        if df[col].apply(lambda x: isinstance(x, (list, tuple)) and len(x) == 1).any():
            df[col] = df[col].apply(lambda x: x[0] if isinstance(x, (list, tuple)) and len(x) == 1 else x)
    return df

def standardize_latlon(df: pd.DataFrame) -> pd.DataFrame:
    """
    - 统一列名到 latitudes / longitudes
    - 若出现同名重复列,保留第一列并删除其余
    - longitudes 位置保持不动,把 latitudes 放到其右侧
    """
    # ---------- ① 统一列名 ----------
    col_map = {}
    for col in df.columns:
        low = col.lower()
        if "lat" in low and "lon" not in low:
            col_map[col] = "latitudes"
        elif ("lon" in low or "lng" in low):
            col_map[col] = "longitudes"
    df = df.rename(columns=col_map)

    # ---------- ② Handle duplicate columns ----------
    # renaming can leave several identically named columns; keep only the first occurrence of each
    while df.columns.duplicated().any():
        dup_col = df.columns[df.columns.duplicated()][0]
        # keep the first occurrence and drop every other column with the same name
        first_idx = list(df.columns).index(dup_col)
        keep = [True] * len(df.columns)
        for i, c in enumerate(df.columns):
            if c == dup_col and i != first_idx:
                keep[i] = False
        df = df.loc[:, keep]

    # ---------- ③ Convert to numeric ----------
    for c in ("latitudes", "longitudes"):
        if c in df.columns and not isinstance(df[c], pd.Series):
            # if duplicates slipped through, this can still be a DataFrame; take the first column
            df[c] = df[c].iloc[:, 0]
        if c in df.columns:
            df[c] = df[c].apply(
                lambda x: x[0] if isinstance(x, (list, tuple)) and len(x) == 1 else x
            )
            df[c] = pd.to_numeric(df[c], errors="coerce")

    # ---------- ④ Reorder: put latitudes right after longitudes ----------
    if {"longitudes", "latitudes"}.issubset(df.columns):
        cols = list(df.columns)
        lon_idx = cols.index("longitudes")
        lat_idx = cols.index("latitudes")
        if lat_idx != lon_idx + 1:
            cols.pop(lat_idx)
            cols.insert(lon_idx + 1, "latitudes")
            df = df[cols]

    return df



# ===== CSV fallback query =====
@st.cache_data(show_spinner=False)
def rag_csv_fallback(subtask, top_k=2000):
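    """Embed the subtask and pull the top-k nearest rows from each CSV-backed FAISS index, concatenated and deduplicated.

    Example (illustrative): rag_csv_fallback("Snake observations in Florida", top_k=500)
    """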
    encoder = get_encoder()
    query_vec = encoder.encode([subtask])
    query_vec = normalize(query_vec, axis=1).astype("float32")
    if not np.any(query_vec):
        return pd.DataFrame()
    all_results = []
    for index, metadata in zip(csv_faiss_indexes, csv_metadatas):
        if index is None or metadata.empty:
            continue
        distances, indices = index.search(query_vec, top_k)
        retrieved = metadata.iloc[indices[0]].copy()
        all_results.append(retrieved)
    if all_results:
        return pd.concat(all_results).drop_duplicates().reset_index(drop=True)
    return pd.DataFrame()



def generate_cypher_with_gpt(subtask: str) -> str:
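    """Ask GPT-4o for a JSON object containing a Cypher query for the subtask; return the query string, or "" if parsing fails."""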
    prompt = f"""
You are an expert Cypher query generator for a Neo4j biodiversity database. The schema is as follows:

Node Types and Properties:
- Observation: animal_name, date, latitude, longitude
- Species: name, species_full_name
- Site: name
- County: name
- State: name
- Hurricane: name
- Policy: title, description
- ClimateEvent: event_type, date
- TemperatureReading: value, date, location
- Precipitation: amount, date, location

Relationship Types:
- OBSERVED_IN: (Observation)-[:OBSERVED_IN]->(Site)
- OBSERVED_ORGANISM: (Observation)-[:OBSERVED_ORGANISM]->(Species)
- BELONGS_TO: (Site)-[:BELONGS_TO]->(County)
- IN_COUNTY: (Observation)-[:IN_COUNTY]->(County)
- IN_STATE: (County)-[:IN_STATE]->(State)
- interactsWith: (Species)-[:interactsWith]->(Species)
- preysOn: (Species)-[:preysOn]->(Species)

Your task is to generate a **precise and efficient** Cypher query for the following subtask:
"{subtask}"

Guidelines:
- Do NOT return all nodes of a type (e.g., all Species) unless the subtask explicitly asks for it.
- If a location (county/state) is mentioned or implied, include location filtering using IN_COUNTY, IN_STATE, or BELONGS_TO.
- If the subtask implies a taxonomic or common name group (e.g., frog, snake, salmon), apply CONTAINS or STARTS WITH filters on Species.name or species_full_name, using toLower(...) for case-insensitive matching.
- If the subtask includes a time range, include date filtering.
- Prefer using DISTINCT to avoid redundant results.
- Only return fields that are clearly needed to fulfill the subtask.

Return your response strictly as a **JSON object** with the following fields:
- "intent": a short description of what the query does
- "cypher_query": the Cypher query
- "fields": a list of returned field names (e.g., ["species", "county", "date"])

Do not include any explanation or commentary—only return the JSON object.
"""

    
    client = openai.OpenAI(api_key=os.getenv("openai_api_key"))
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        temperature=0
    )
    content = response.choices[0].message.content.strip()
    content = re.sub(r"^(json|python)?", "", content, flags=re.IGNORECASE).strip()
    content = re.sub(r"$", "", content).strip()

    try:
        cypher_json = json.loads(content)
        return cypher_json["cypher_query"]
    except Exception as e:
        return ""
    

def intelligent_retriever_agent(subtask, saved_hashes=None):
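    """Retrieve data for a subtask: run a GPT-generated Cypher query against Neo4j, falling back to FAISS/CSV retrieval when the result is empty, too small, or a duplicate."""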
    if saved_hashes is None:
        saved_hashes = set()
    st.success("🔍 Attempting to retrieve data from the Ecodata knowledge graph…")
    cypher_query = generate_cypher_with_gpt(subtask)
    cypher_df = pd.DataFrame()
    if cypher_query.strip():
        st.code(cypher_query, language="cypher")
        # strip any trailing LIMIT so the full result set is returned
        query = re.sub(r"(?i)LIMIT\s+\d+\s*$", "", cypher_query)
        try:
            with driver.session() as session:
                result = session.run(query)
                cypher_df = pd.DataFrame(result.data())
        except Exception as e:
            st.error(f"🚨 Cypher execution error: {e}")
            st.code(query, language="cypher")
    # decide fallback
    fallback_needed = False
    if cypher_df.empty:
        # st.warning("⚠️ Cypher query returned no data. Trying CSV fallback…")
        fallback_needed = True
    else:
        df_hash = hashlib.md5(cypher_df.to_csv(index=False).encode()).hexdigest()
        st.write(f"ℹ️ Cypher rows: {len(cypher_df)} | duplicate?: {df_hash in saved_hashes}")
        if df_hash in saved_hashes or len(cypher_df) < 10:
            fallback_needed = True
    if fallback_needed:
        csv_df = rag_csv_fallback(subtask)
        if not csv_df.empty:
            csv_df = flatten_props(csv_df)
            csv_df = unpack_singletons(csv_df)
            csv_df = standardize_latlon(csv_df) 
            # st.success("✅ CSV fallback successful.")
            return csv_df
        st.warning("❌ CSV fallback also returned nothing.")
        return pd.DataFrame()
    # good cypher
    st.success("✅ Cypher query successful. Using Cypher result.")
    cypher_df = flatten_props(cypher_df)
    cypher_df = unpack_singletons(cypher_df)
    cypher_df = standardize_latlon(cypher_df)
    if "species" not in cypher_df.columns and "animal_name" in cypher_df.columns:
        cypher_df["species"] = cypher_df["animal_name"]
    if "date" in cypher_df.columns:
        cypher_df["date"] = pd.to_datetime(cypher_df["date"], errors="coerce")
    cypher_df.rename(columns={"latitudes": "latitude", "longitudes": "longitude", "lat": "latitude", "lon": "longitude"}, inplace=True)
    for col in ("latitude", "longitude"):
        if col in cypher_df.columns:
            cypher_df[col] = pd.to_numeric(cypher_df[col], errors="coerce")
    return cypher_df


def planner_agent(question: str) -> str:
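    """Ask GPT-4o to break the research question into 1–3 'Dataset Need' blocks (title + description)."""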
    prompt = f"""
You are a **research‑data planning assistant**.

------------------------  📝  TASK  ------------------------
Your job is to list the **separate data sets** a researcher must collect
to answer the research question below.

*Each data set* should be focused on one clearly defined entity or
phenomenon (e.g. "Tracks of hurricanes affecting Florida since 1950",
"Geo‑tagged snake observations in Florida 2000‑present").

--------------------  📋  OUTPUT FORMAT  --------------------
Write 1–3 blocks.  For **each** block use *both* lines exactly:

Dataset Need X: <Concise title, ≤ 10 words>  
Description: <Why this data matters—1 short sentence>

⚠️  Do NOT add extra lines or markdown.  
⚠️  Keep variable names short; no code blocks; no quotes.

--------------------  🔍  RESEARCH QUESTION  --------------------
{question}
"""
    rsp = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "You are an expert research planner."},
            {"role": "user", "content": prompt}
        ],
        temperature=0.2
    )
    return rsp.choices[0].message.content.strip()



def evaluate_dataset_with_gpt(subtask: str, df: pd.DataFrame, client=openai_client) -> str:
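    """Ask GPT-4o to summarize a dataset preview; if it looks irrelevant to the subtask, the reply instead describes the data and lists external resources."""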
    max_columns = 50
    selected_cols = df.columns[:max_columns]
    column_info = {col: str(df[col].dtype) for col in selected_cols}
    sample_rows = df.head(3)[selected_cols].to_dict(orient="records")   # take 3 example rows

    prompt = f"""
You are a data‑validation assistant. Decide whether the dataset below is useful for the research subtask.

===== TASK =====
Subtask: "{subtask}"

===== DATASET PREVIEW =====
Schema (first {len(selected_cols)} columns):
{json.dumps(column_info, indent=10)}
Sample rows (3 max):
{json.dumps(sample_rows, indent=10)}

===== OUTPUT INSTRUCTIONS (follow strictly) =====
Case A – Relevant:
• Write exactly two sentences, each no more than 30 words.
• Summarize what the dataset contains and why it helps the subtask.
• Do not mention column names or list individual rows.

Case B – Not relevant:
• Write one or two sentences, each no more than 30 words, **describing only what the dataset contains**.
• Do **not** mention the subtask, relevance, suitability, limitations, or missing information (avoid phrases like “not related,” “does not focus,” “irrelevant,” etc.).
• After the sentences, output the header **Additionally, here are some external resources you might find helpful:** on a new line. Format your output in markdown as:
- [Name of Source](URL)
• Then list 2–3 bullet points, each on its own line, starting with “- ” followed immediately by a URL likely to contain the needed data.
• No additional commentary.



General rules:
Plain text only — no code fences. Markdown link syntax (`[text](url)`) is allowed.
"""

    rsp = client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.3,
    )
    return rsp.choices[0].message.content.strip()


def external_resource_recommender(subtask: str, client=openai_client) -> str:
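    """Ask GPT-4o for three trusted, publicly accessible data sources relevant to the subtask, formatted as a markdown link list."""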
    prompt = f"""
You are a helpful research assistant. Your task is to recommend **three reliable, publicly accessible online datasets or data repositories** that can assist with the following scientific subtask:

{subtask}

Only include sources that are:
- Trusted (e.g., government, academic, or well-established platforms)
- Relevant to the topic
- Accessible without login when possible

Format your answer strictly in markdown:
- [Name of Source](URL)
- [Name of Source](URL)
- [Name of Source](URL)

Do not include any explanations or extra text—only the list.
"""
    rsp = client.chat.completions.create(                      
        model="gpt-4o",
        messages=[{"role": "user", "content": prompt}],
        temperature=0.3
    )
    return rsp.choices[0].message.content.strip()




def fallback_query_router(subtask: str, driver) -> pd.DataFrame:
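    """Route the subtask to one of several canned Cypher queries by keyword; if the result is empty, suggest external resources."""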
    text = subtask.lower()

    with driver.session() as session:

        # --- 1. Species "where … observed / found …" ---
        if "where" in text and ("observed" in text or "found" in text):
            query = """
            MATCH (o:Observation)-[:OBSERVED_ORGANISM]->(s:Species)
            RETURN s.name AS species, o.site_name AS location, o.date AS date
            ORDER BY o.date DESC
            """

        # --- 2. before / after a given year ---
        elif "before" in text or "after" in text:
            # non-capturing group so findall returns full years (e.g. "1998"), not just "19"/"20"
            years = re.findall(r'\b(?:19|20)\d{2}\b', text)
            if years:
                op = "<" if "before" in text else ">"
                query = f"""
                MATCH (o:Observation)-[:OBSERVED_ORGANISM]->(s:Species)
                WHERE o.date {op} date('{years[0]}-01-01')
                RETURN s.name AS species, o.site_name AS location, o.date AS date
                ORDER BY o.date DESC
                """
            else:
                query = "RETURN 1"

        # --- 3. Hurricane-related ---
        elif "hurricane" in text:
            query = """
            MATCH (o:Observation)-[:OBSERVED_AT]->(h:Hurricane),
                  (o)-[:OBSERVED_ORGANISM]->(s:Species),
                  (o)-[:OBSERVED_IN]->(site)-[:BELONGS_TO]->(c:County)-[:IN_STATE]->(st:State)
            WHERE st.name = 'Florida'
            RETURN h.name AS hurricane,
                   s.name AS species,
                   site.name AS site,
                   c.name AS county,
                   o.date AS date
            ORDER BY o.date DESC
            """

        # --- 4. Predation / predator ---
        elif "preys on" in text or "predator" in text:
            query = """
            MATCH (s1:Species)-[:preysOn]->(s2:Species)
            RETURN s1.name AS predator, s2.name AS prey
            """

        # --- 5. Default fallback ---
        else:
            query = """
            MATCH (o:Observation)
            RETURN o.animal_name AS species, o.site_name AS location, o.date AS date
            """

        # --- Execute the query ---
        result = session.run(query)
        df = pd.DataFrame(result.data())

        if df.empty:
            st.info("🌐 I couldn't find relevant data in KN‑Wildlife. Let me check external sources for you...")
            suggestions = external_resource_recommender(subtask)
            st.markdown(suggestions)

        return df


def save_dataset(df: pd.DataFrame, filename: str) -> str:
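    """Persist a dataframe (≥ 10 rows) to /tmp/saved_datasets/<filename>.csv, skipping the write when an identical file already exists; return the saved path."""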
    if len(df) < 10:
        st.warning(f"❌ Dataset too small to save: only {len(df)} rows.")
        return ""
    save_dir = "/tmp/saved_datasets"
    os.makedirs(save_dir, exist_ok=True)
    path = f"{save_dir}/{filename}.csv"
    if os.path.exists(path):
        old_hash = hashlib.md5(open(path, 'rb').read()).hexdigest()
        new_hash = hashlib.md5(df.to_csv(index=False).encode()).hexdigest()
        if old_hash == new_hash:
            st.info(f"ℹ️ Dataset saved: {filename}.csv")
            return path
    df.to_csv(path, index=False)
    st.info(f"✅ Dataset saved: {filename}.csv")

    return path
# ===================== CHART SUGGESTION (MODIFIED MAP SECTION) =====================

def suggest_charts_with_gpt(df: pd.DataFrame) -> str:
    """Generate Streamlit chart code for automatic visualisation."""
    try:
        # st.write("🟢 COLS‑DEBUG:", list(df.columns))

        # Ensure dates are parsed
        if "date" in df.columns:
            df["date"] = df["date"].apply(lambda x: x[0] if isinstance(x, (list, tuple)) and len(x) == 1 else x)
            df["date"] = pd.to_datetime(df["date"], errors="coerce")

        if "animal_name" in df.columns and "species" not in df.columns:
            df["species"] = df["animal_name"]

        df.rename(columns={"latitudes": "latitude", "longitudes": "longitude"}, inplace=True)

        chart_code = """
# --- Species Bar Chart ---
if 'species' in df.columns:
    st.markdown('📊 Count of Observations by Species')
    try:
        species_counts = df['species'].astype(str).value_counts()
        st.bar_chart(species_counts)
    except Exception as e:
        st.warning(f'⚠️ Could not render species chart: {e}')

# --- Timeline Line Chart ---
if 'date' in df.columns:
    st.markdown('📈 Observations Over Time')
    try:
        timeline = df['date'].dropna().value_counts().sort_index()
        st.line_chart(timeline)
    except Exception as e:
        st.warning(f'⚠️ Could not render date chart: {e}')

# --- Map Visualisation (highlight all points) ---
if 'latitude' in df.columns and 'longitude' in df.columns:
    st.markdown('🗺️ Observation Locations on Map')
    try:
        coords = df[['latitude', 'longitude']].dropna()
        coords = coords[(coords['latitude'].between(-90, 90)) & (coords['longitude'].between(-180, 180))]
        
        if len(coords) == 0:
            raise Exception('⚠️ No valid coordinates to plot on the map.')
        else:
            # compute the map center
            center = [coords['latitude'].mean(), coords['longitude'].mean()]
            m = folium.Map(location=center, zoom_start=5)
        
            # add a circle marker for each observation
            for _, row in coords.iterrows():
                folium.CircleMarker(
                    location=[row['latitude'], row['longitude']],
                    radius=5,
                    color='green',
                    fill=True,
                    fill_color='green',
                    fill_opacity=0.7,
                ).add_to(m)
        
            st_folium(m, width=700, height=500)
    except Exception as e:
        st.warning(f'⚠️ Could not render map: {e}')
"""
        return textwrap.dedent(chart_code)
    except Exception as outer_error:
        return f"st.warning('❌ Chart generation failed: {outer_error}')"




# ========= UI layout and connection ==========
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
    
# Note: st.set_page_config is called at the top of the file, since it must be the first Streamlit command.

# ——— Customize the main container width and base font size ———
st.markdown(
    """
    <style>
        /* Body text */
        html, body, .block-container, .markdown-text-container {
            font-size: 19px !important;     /* ← adjust this number to change the base font size */
            line-height: 1.6 !important;
        }
        /* Widen the default narrow max-width (~700px) to 1600px; adjust as needed */
        .block-container {
            max-width: 1600px;
        }
    </style>
    """,
    unsafe_allow_html=True
)

st.title("🐾 Quest2DataAgent_EcoData")


st.success("""
👋 Hi there! I’m **Lily**, your research assistant bot 🤖. I’m here to help you explore data sources related to your **complex research question**. Let’s work together to find the information you need!

💡 You can start by entering a research question like:

- *In Florida, how do hurricanes affect the distribution of snakes?*  
- *How does precipitation impact salmon abundance in freshwater ecosystems?*  
- *How do climate change and urbanization jointly affect bird migration and diversity in Florida?*
""")

if driver:
    st.success("🟢 Connected to **Ecodata** — a Neo4j-powered biodiversity graph focused on  species and ecosystems.   I’ll start by checking what relevant data we already have in Ecodata to support your research.")

else:
    st.error("🔴 Failed to connect to Ecodata! Please fix connection first.")
    st.stop()

question = st.text_area("Enter your research question:", "")

# Initialize session-state variables
if "start_clicked" not in st.session_state:
    st.session_state.start_clicked = False
if "subtask_plan" not in st.session_state:
    st.session_state.subtask_plan = ""
if "ready_to_continue" not in st.session_state:
    st.session_state.ready_to_continue = False
if "stop_requested" not in st.session_state:
    st.session_state.stop_requested = False
if "visualization_ready" not in st.session_state:
    st.session_state.visualization_ready = False
if "do_visualize" not in st.session_state:
    st.session_state.do_visualize = False
if "all_dataframes" not in st.session_state:
    st.session_state.all_dataframes = [] 
if "retrieval_done" not in st.session_state:
    st.session_state.retrieval_done = False

# Clicking the button triggers subtask decomposition
if st.button("Let’s start") and question.strip():
    st.session_state.start_clicked = True
    st.session_state.subtask_plan = planner_agent(question)
    st.session_state.ready_to_continue = False
    st.session_state.stop_requested = False
    st.session_state.visualization_ready = False
    st.session_state.do_visualize = False
    st.session_state.all_dataframes = []
    st.session_state.retrieval_done = False

# Stage 1: show the subtasks
if st.session_state.start_clicked:
    # st.success("🧠 Now, I’ll break down your research question into several focused subtasks.")
    st.success("🧠 I’ve identified the distinct datasets you’ll need for this research question.")
    with st.expander("🔹 Curious how I split your question? Click to see!", expanded=True):
        st.write(st.session_state.subtask_plan)

    st.success("📌 I’m ready to roll up my sleeves — shall I start finding datasets for each subtask? 🕒 This step might take a little while, so thanks for your patience!")

    col1, col2 = st.columns([1, 1])
    with col1:
        if st.button("✅ Yes, go ahead", key="confirm_button"):
            st.session_state.ready_to_continue = True
            st.session_state.stop_requested = False
    with col2:
        if st.button("⛔ No, stop here", key="stop_button"):
            st.session_state.ready_to_continue = False
            st.session_state.stop_requested = True


# ---------- Stage 2: data retrieval & rendering ----------
if st.session_state.ready_to_continue:

    # ① First determine which prefix the planner output uses
    #    (assume only two possibilities: Subtask / Dataset Need)
    if "Dataset Need" in st.session_state.subtask_plan:
        prefix = "Dataset Need"
    else:
        prefix = "Subtask"

    # ② Build the regex with an f-string (rf = raw + formatted)
    pattern = rf"{prefix} \d+:.*?(?={prefix} \d+:|$)"
    subtasks = re.findall(pattern,
                          st.session_state.subtask_plan,
                          flags=re.DOTALL)

    # If the planner produced no blocks, warn and stop
    if not subtasks:
        st.warning("⚠️ No dataset blocks detected in planner output.")
        st.stop()

    # Retrieval runs only once per question
    if not st.session_state.retrieval_done:                      # ★
        progress_bar = st.progress(0)
        total = len(subtasks)
        saved_hashes = set()
        st.session_state.all_dataframes = []


    for idx, subtask in enumerate(subtasks):
        # with st.expander(f"🔹 Retrieving data for subtask {idx+1}:", expanded=True):
        with st.expander(f"🔹 Retrieving data for dataset need {idx+1}:", expanded=True):
            cleaned_subtask = "\n".join(subtask.strip().split("\n")[1:])
            st.markdown(cleaned_subtask)

            # ---------- First run: actually retrieve ----------
            if not st.session_state.retrieval_done:              # ★
                df = intelligent_retriever_agent(subtask, saved_hashes)

                if not df.empty:
                    df_hash = hashlib.md5(df.to_csv(index=False).encode()).hexdigest()
                    if df_hash in saved_hashes:
                        st.warning("⚠️ This dataset has already been saved — skipping duplicate.")
                    elif len(df) < 10:
                        st.warning(f"❌ This dataset is too small — just {len(df)} rows. Skipping save.")
                    else:
                        saved_hashes.add(df_hash)
                        df = flatten_props(df)
                        df = standardize_latlon(df)
                        summary = evaluate_dataset_with_gpt(subtask, df)
                        st.session_state.all_dataframes.append({              
                            "hash": df_hash,
                            "df": df,
                            "summary": summary
                            })
                        st.dataframe(df.head(50))
                        save_path = save_dataset(df, f"subtask_{idx+1}")
                        if save_path:
                            st.markdown("**📝 Dataset Introduction:**")
                            st.write(summary)
                            # Add a download button
                            with open(save_path, "rb") as f:
                                st.download_button(
                                    label="📥 Download dataset (CSV)",
                                    data=f,
                                    file_name=os.path.basename(save_path),
                                    mime="text/csv",
                                    key=f"download_init_{idx}"
                                )

                if 'progress_bar' in locals():
                    progress_bar.progress((idx + 1) / total)

            # ---------- Subsequent reruns: display only ----------
            else:  
                if idx < len(st.session_state.all_dataframes):
                    entry = st.session_state.all_dataframes[idx]
                    df = standardize_latlon(entry["df"])
                    st.dataframe(df.head(50))
            
                    st.markdown("**📝 Dataset Introduction:**")
                    st.write(entry.get("summary", ""))
                    # Add a download button
                    tmp_path = f"/tmp/subtask_{idx+1}_display.csv"
                    df.to_csv(tmp_path, index=False)
                    with open(tmp_path, "rb") as f:
                        st.download_button(
                            label="📥 Download dataset (CSV)",
                            data=f,
                            file_name=os.path.basename(tmp_path),
                            mime="text/csv",
                            key=f"download_rerun_{idx}" 
                        )



    # Mark retrieval as done after the first pass
    if not st.session_state.retrieval_done:                     # ★
        st.session_state.retrieval_done = True
        st.session_state.visualization_ready = bool(st.session_state.all_dataframes)



    if st.session_state.all_dataframes:
        st.session_state.visualization_ready = True
    else:
        st.success("🎉 All subtasks completed and datasets generated!💡 Feel free to ask me more questions anytime!")
        # st.success("🎉 All subtasks completed and datasets generated!")
        # st.success("💡 Feel free to ask Lily more questions anytime!")

# Stage 3: ask whether to visualize
if st.session_state.visualization_ready and not st.session_state.do_visualize:
    st.success("📊 All set! I’ve gathered the datasets. Ready to visualize them?")

    col1, col2 = st.columns([1, 1])
    with col1:
        if st.button("✅ Yes, go ahead", key="viz_confirm"):
            st.session_state.do_visualize = True
    with col2:
        if st.button("⛔ No, stop here", key="viz_stop"):
            st.session_state.visualization_ready = False
            st.success("🎉 All subtasks completed and datasets generated!💡 Feel free to ask me more questions anytime!")
            # st.success("🎉 All subtasks completed and datasets generated!")
            # st.success("💡 Feel free to ask Lily more questions anytime!")


# Stage 3: render the visualizations
if st.session_state.do_visualize:
    for i, entry in enumerate(st.session_state.all_dataframes): 
        df = entry["df"]
        summary = entry.get("summary", "")  
        if len(df) < 10:
            continue
        with st.expander(f"**🔹 Dataset {i + 1} Visualization**", expanded=True):
            st.markdown(f"Dataset {i + 1} Preview")
            st.dataframe(df.head(10))
            chart_code = suggest_charts_with_gpt(df)
            if chart_code:
                try:
                    exec(chart_code, {"st": st, "pd": pd, "df": df, "pdk": pdk, "folium": folium, "st_folium": st_folium})
                except Exception as e:
                    st.error(f"❌ Error running chart code: {e}")


    st.success("🎉 All subtasks completed and datasets generated!💡 Feel free to ask me more questions anytime!")




if st.session_state.stop_requested:
    st.info("👍 No problem! You can review the subtasks above or revise your question.") 



# —— Sidebar: ChatGPT-style chat panel ——
with st.sidebar.expander("💬 Chat with Lily", expanded=True):
    # Chat input box
    user_msg = st.chat_input("Type your question here…", key="sidebar_chat_input")
    if user_msg:
        # Assemble the current on-screen context
        context_parts = []
        if st.session_state.subtask_plan:
            context_parts.append("Subtasks:\n" + st.session_state.subtask_plan)
        for entry in st.session_state.all_dataframes:
            context_parts.append("Data summary:\n" + entry["summary"])
        page_context = "\n\n".join(context_parts)

        # Call the GPT helper
        with st.spinner("Lily is thinking…"):
            assistant_msg = gpt_chat(
                sys_msg=f"You are Lily, a research assistant. Here’s what’s on screen:\n\n{page_context}",
                user_msg=user_msg
            )

        # Save the exchange to chat history
        st.session_state.chat_history.append({"role": "user",      "content": user_msg})
        st.session_state.chat_history.append({"role": "assistant", "content": assistant_msg})

    # Render chat history
    for msg in st.session_state.chat_history:
        if msg["role"] == "user":
            st.chat_message("user").write(msg["content"])
        else:
            st.chat_message("assistant").write(msg["content"])