ashaddams committed on
Commit
d94c37b
·
verified ·
1 Parent(s): f373e28

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +652 -0
app.py ADDED
@@ -0,0 +1,652 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ===============================================
2
+ # Algae Yield Predictor — Uncertainty + Response Plot
3
+ # (Hugging Face Spaces–ready)
4
+ # ===============================================
5
+
6
+ import re
7
+ import numpy as np
8
+ import pandas as pd
9
+ import gradio as gr
10
+ import matplotlib.pyplot as plt
11
+ from pathlib import Path
12
+ from difflib import get_close_matches
13
+
14
+ from sklearn.preprocessing import LabelEncoder
15
+ from sklearn.impute import SimpleImputer
16
+ from sklearn.neighbors import NearestNeighbors
17
+
18
+ from catboost import CatBoostRegressor
19
+ from gradio.themes import Soft
20
+
21
# -----------------------------
# Paths (relative to repo root)
# -----------------------------
HERE = Path(__file__).parent
RAW_PATH = HERE / "ai_al.csv"  # required: real experimental dataset
DOI_PATH = HERE / "doi.csv"  # optional: literature reference table
MODEL_DIR = HERE / "models"  # optional cache of pre-trained CatBoost .cbm files
# Create the cache directory eagerly so save_model never fails on a fresh clone.
MODEL_DIR.mkdir(parents=True, exist_ok=True)
29
+
30
+ # -----------------------------
31
+ # Helpers
32
+ # -----------------------------
33
def extract_first_float(x: str):
    """Return the first numeric literal found in *x*, or NaN when absent/missing."""
    if pd.isna(x):
        return np.nan
    match = re.search(r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?", str(x))
    if match is None:
        return np.nan
    return float(match.group(0))
38
+
39
def parse_cycle_first(x: str):
    """Parse 'A:B' cycle notation and return A; otherwise the first float in *x*."""
    if pd.isna(x):
        return np.nan
    text = str(x)
    pair = re.search(r"(\d+(?:\.\d+)?)\s*:\s*(\d+(?:\.\d+)?)", text)
    if pair:
        return float(pair.group(1))
    return extract_first_float(text)
44
+
45
def coerce_numeric(series: pd.Series, mode: str = "float"):
    """Coerce a series to floats; mode 'cycle_first' keeps A from 'A:B' strings."""
    parser = parse_cycle_first if mode == "cycle_first" else extract_first_float
    return series.apply(parser)
47
+
48
def normalize_str(x):
    """Lower-cased, stripped string form of *x*; missing values become the literal 'nan'."""
    return "nan" if pd.isna(x) else str(x).strip().lower()
51
+
52
# -----------------------------
# Curated suggestions
# -----------------------------
# Hand-curated light-intensity and culture-duration hints, keyed by normalized
# species label, then by prediction target. Light values appear to be in
# μmol·m⁻²·s⁻¹ (matching the UI slider label) — TODO(review): confirm sources.
SPECIES_SUGGESTIONS = {
    "a. platensis": {
        "biomass": {"light": "60–300", "days": "15–25"},
        "lipid": {"light": "High light intensity (stress)", "days": "15–25"},
        "protein": {"light": "60–300", "days": "12–18"},
        "carb": {"light": "60–300", "days": "15–25"},
    },
    "c. pyrenoidosa": {
        "biomass": {"light": "50–150", "days": "12–25"},
        "lipid": {"light": "High light intensity (stress)", "days": "12–25"},
        "protein": {"light": "50–150", "days": "12–18"},
        "carb": {"light": "50–150", "days": "12–25"},
    },
    "c. sorokiniana": {
        "biomass": {"light": "60–300", "days": "15–25"},
        "lipid": {"light": "High light intensity (stress)", "days": "15–25"},
        "protein": {"light": "60–300", "days": "12–18"},
        "carb": {"light": "60–300", "days": "15–25"},
    },
    "c. variabilis": {
        "biomass": {"light": "60–250", "days": "15–25"},
        "lipid": {"light": "High light intensity (stress)", "days": "15–25"},
        "protein": {"light": "60–250", "days": "12–18"},
        "carb": {"light": "60–250", "days": "15–25"},
    },
    "c. vulgaris": {
        "biomass": {"light": "60–300", "days": "12–21"},
        "lipid": {"light": "High light intensity (stress)", "days": "15–21"},
        "protein": {"light": "60–300", "days": "12–18"},
        "carb": {"light": "60–300", "days": "12–21"},
    },
    "c. zofingiensis": {
        "biomass": {"light": "50–150", "days": "25–30"},
        "lipid": {"light": "High light intensity (stress)", "days": "25–30"},
        "protein": {"light": "50–150", "days": "25–30"},
        "carb": {"light": "50–150", "days": "25–30"},
    },
    "h. pluvialis": {
        "biomass": {"light": "50–250", "days": "25–30"},
        "lipid": {"light": "High light intensity (stress)", "days": "25–30"},
        "protein": {"light": "50–250", "days": "25–30"},
        "carb": {"light": "50–250", "days": "25–30"},
    },
    "p. purpureum": {
        "biomass": {"light": "100–250", "days": "17–19"},
        "lipid": {"light": "High light intensity (stress)", "days": "17–19"},
        "protein": {"light": "100–250", "days": "12–15"},
        "carb": {"light": "100–250", "days": "17–19"},
    },
    "scenedesmus sp.": {
        "biomass": {"light": "50–250", "days": "12–25"},
        "lipid": {"light": "High light intensity (stress)", "days": "12–25"},
        "protein": {"light": "50–250", "days": "12–20"},
        "carb": {"light": "50–250", "days": "12–25"},
    },
}
111
+
112
+ def _normalize_species_label(s: str) -> str:
113
+ if s is None: return ""
114
+ s0 = str(s).strip().lower()
115
+ s1 = re.sub(r"[_\-]+", " ", s0).replace(" ", " ").strip()
116
+ s2 = s1.replace(" .", ".").replace(". ", ". ")
117
+ alias = {
118
+ "a platensis": "a. platensis", "a.platensis": "a. platensis", "arthrospira platensis": "a. platensis",
119
+ "c pyrenoidosa": "c. pyrenoidosa", "c.pyrenoidosa": "c. pyrenoidosa", "chlorella pyrenoidosa": "c. pyrenoidosa",
120
+ "c sorokiniana": "c. sorokiniana", "c.sorokiniana": "c. sorokiniana",
121
+ "c variabilis": "c. variabilis", "c.variabilis": "c. variabilis",
122
+ "c vulgaris": "c. vulgaris", "c.vulgaris": "c. vulgaris", "chlorella vulgaris": "c. vulgaris",
123
+ "c zofingiensis": "c. zofingiensis", "c.zofingiensis": "c. zofingiensis",
124
+ "h pluvialis": "h. pluvialis", "h.pluvialis": "h. pluvialis", "haematococcus pluvialis": "h. pluvialis",
125
+ "p purpureum": "p. purpureum", "p.purpureum": "p. purpureum", "porphyridium purpureum": "p. purpureum",
126
+ "scenedesmus": "scenedesmus sp.", "scenedesmus sp": "scenedesmus sp.", "scenedesmus sp.": "scenedesmus sp.",
127
+ }
128
+ return alias.get(s2, s2)
129
+
130
def _format_suggestion_md(species: str, target: str) -> str:
    """Markdown blurb with curated light/duration hints for a species+target pair."""
    sp_key = _normalize_species_label(species)
    tgt_key = (target or "").strip().lower()
    hints = SPECIES_SUGGESTIONS.get(sp_key, {}).get(tgt_key)
    if hints is None:
        return f"> ℹ️ No curated suggestion for **{species}** and **{target}**."
    header = f"### 💡 Suggested conditions for *{sp_key}* → *{tgt_key}*\n"
    body = f"**Light intensity:** {hints['light']}  |  **Days:** {hints['days']}"
    return header + body
140
+
141
def update_suggestion_panel(target, species):
    """Gradio callback: refresh the suggestion markdown (note swapped arg order)."""
    return _format_suggestion_md(species=species, target=target)
143
+
144
# -----------------------------
# Load and normalize real data
# -----------------------------
# Read the experimental dataset and canonicalize headers: strip, lower-case,
# and replace non-alphanumeric runs with "_" (e.g. a "°C" column becomes "_c").
df_raw = pd.read_csv(RAW_PATH)
df_raw.columns = (
    df_raw.columns.str.strip()
    .str.lower()
    .str.replace("[^0-9a-zA-Z]+", "_", regex=True)
)

# Canonical model feature schema; "_c" holds temperature after header munging.
FEATURES = ["species","media","light","expo_day","expo_night","_c","ph","days"]
CATEGORICAL = ["species","media"]
NUM_CYCLE_FIRST = ["light"]  # may arrive as "A:B" cycle-style strings
NUM_PLAIN = ["expo_day","expo_night","_c","ph","days"]
TARGETS = ["biomass","lipid","protein","carb"]

# Normalize categorical text before fitting the label encoders.
df_enc = df_raw.copy()
for col in CATEGORICAL:
    if col in df_enc.columns:
        df_enc[col] = df_enc[col].map(normalize_str)

# Fit one LabelEncoder per categorical column on the CSV's own vocabulary;
# keep the sorted non-missing values for dropdowns and validation.
encoders, value_lists = {}, {}
for col in CATEGORICAL:
    le = LabelEncoder()
    vals = df_enc[col].astype(str).fillna("nan")
    le.fit(vals)
    encoders[col] = le
    value_lists[col] = sorted(set(vals) - {"nan"})

# Coerce the numeric columns so the imputer can be fitted downstream.
for c in NUM_CYCLE_FIRST:
    if c in df_enc.columns:
        df_enc[c] = coerce_numeric(df_enc[c], "cycle_first")
for c in NUM_PLAIN:
    if c in df_enc.columns:
        df_enc[c] = coerce_numeric(df_enc[c], "float")
182
+
183
def encode_frame(df_like: pd.DataFrame) -> pd.DataFrame:
    """Map a raw-ish frame onto the numeric FEATURES matrix the models expect."""
    out = pd.DataFrame()
    # Label-encode categoricals with the encoders fitted on the training CSV.
    for col in CATEGORICAL:
        if col in df_like.columns:
            normalized = df_like[col].map(normalize_str)
            out[col] = encoders[col].transform(normalized.astype(str).fillna("nan"))
    # Numeric columns: 'A:B' cycle strings keep the first number; others parse as floats.
    for col in NUM_CYCLE_FIRST:
        if col in df_like.columns:
            out[col] = coerce_numeric(df_like[col], "cycle_first")
    for col in NUM_PLAIN:
        if col in df_like.columns:
            out[col] = coerce_numeric(df_like[col], "float")
    # Guarantee every expected feature exists, then fix the column order.
    for col in FEATURES:
        if col not in out.columns:
            out[col] = np.nan
    return out[FEATURES]
199
+
200
# Fit a median imputer on the encoded real dataset; it backfills missing
# features both at prediction time and in the augmented KNN pipeline.
X_for_imputer = encode_frame(df_raw)
imputer = SimpleImputer(strategy="median").fit(X_for_imputer)
202
+
203
# -----------------------------
# Species-media vocab + aliases
# -----------------------------
# Curated species -> allowed growth media. Keys are free-form spellings —
# including apparent typos ("h pluvalis", "c pyreniidosa", "zorrouks",
# "erdseirber") that presumably mirror the raw CSV; they are resolved against
# the CSV vocabulary via match_to_vocab, so only a fuzzy match is required.
# TODO(review): confirm the odd spellings are intentional.
ALLOWED_PAIRS_ALIAS = {
    "a.platensis": ["zarrouks", "bg 11"],
    "c sorokiniana": ["tap", "bg 11"],
    "c vulgaris": ["bg 11", "bbm"],
    "scenedesmus": ["bg 11", "bbm"],
    "p purpureum": ["artificial sea water", "erdseirber and bold nv", "f2"],
    "h pluvalis": ["bg 11"],
    "c pyreniidosa": ["bg 11", "bbm", "selenite media"],
    "c zofingensis": ["bg 11", "bbm", "tap"],
    "c variabilis": ["bg 11", "zorrouks", "tap"],
}
# Alternate names for each species key above, used by match_to_vocab.
SPECIES_ALIASES = {
    "a.platensis": ["arthrospira platensis", "spirulina platensis", "a. platensis"],
    "c sorokiniana": ["chlorella sorokiniana", "c. sorokiniana"],
    "c vulgaris": ["chlorella vulgaris", "c. vulgaris"],
    "scenedesmus": ["scenedesmus", "scenedesmus sp.", "desmodesmus sp."],
    "p purpureum": ["porphyridium purpureum", "p. purpureum"],
    "h pluvalis": ["haematococcus pluvialis", "h. pluvialis", "h pluvalis"],
    "c pyreniidosa": ["chlorella pyrenoidosa", "c. pyrenoidosa", "c pyreniidosa"],
    "c zofingensis": ["chromochloris zofingiensis", "c. zofingiensis", "chlorella zofingiensis"],
    "c variabilis": ["chlorella variabilis", "c. variabilis"],
}
# Alternate names for growth media; keys match ALLOWED_PAIRS_ALIAS values.
MEDIA_ALIASES = {
    "zarrouks": ["zarrouk's", "zarrouks", "zarrouk"],
    "zorrouks": ["zarrouk's", "zarrouks", "zarrouk"],
    "bg 11": ["bg 11", "bg-11", "bg11"],
    "bbm": ["bbm", "bold's basal medium", "bold basal medium", "bolds basal medium"],
    "tap": ["tap", "tap water"],
    "artificial sea water": ["artificial sea water", "artificial seawater", "asw"],
    "erdseirber and bold nv": ["erdschreiber and bold nv", "erdschreiber", "bold nv", "bold's nv", "erdschreiber & bold nv"],
    "f2": ["f/2", "guillard f/2", "f2"],
    "selenite media": ["selenite medium", "selenite media"],
}
239
+
240
def match_to_vocab(name: str, vocab: list[str], aliases: dict[str, list[str]], cutoff=0.6):
    """Resolve *name* to a *vocab* entry via exact, then alias, then fuzzy match."""
    key = normalize_str(name)
    if key in vocab:
        return key
    for candidate in aliases.get(key, []):
        candidate_n = normalize_str(candidate)
        if candidate_n in vocab:
            return candidate_n
    fuzzy = get_close_matches(key, vocab, n=1, cutoff=cutoff)
    if fuzzy:
        return fuzzy[0]
    return None
248
+
249
species_vocab = value_lists["species"]  # categories actually present in ai_al.csv
media_vocab = value_lists["media"]

# Resolve the curated alias table against the CSV vocabulary; pairs whose
# species or media cannot be matched are silently dropped.
ALLOWED_PAIRS = {}
for s_alias, m_aliases in ALLOWED_PAIRS_ALIAS.items():
    s_canon = match_to_vocab(s_alias, species_vocab, SPECIES_ALIASES)
    if not s_canon:
        continue
    canon_media = []
    for m_alias in m_aliases:
        m_canon = match_to_vocab(m_alias, media_vocab, MEDIA_ALIASES)
        if m_canon:
            canon_media.append(m_canon)
    if canon_media:
        ALLOWED_PAIRS[s_canon] = sorted(set(canon_media))

if not ALLOWED_PAIRS:
    # Fallback: allow any species-media (warn in UI)
    ALLOWED_PAIRS = {s: sorted(set(media_vocab)) for s in species_vocab}
    WARN_ALL = True
else:
    WARN_ALL = False

def allowed_media_for(species_norm):
    # Media choices permitted for a normalized species name ([] when unknown).
    return ALLOWED_PAIRS.get(species_norm, [])
274
+
275
# -----------------------------
# Model loader
# -----------------------------
def get_augmented_path(target: str):
    """Return the 200k augmented CSV for *target* if present, else the 20k one, else None."""
    for fname in (f"augmented_{target}_200k.csv", f"augmented_{target}_20k.csv"):
        candidate = HERE / fname
        if candidate.exists():
            return candidate
    return None
282
+
283
def load_or_train_catboost(target: str) -> CatBoostRegressor:
    """Return a CatBoost model for *target*, resolving in priority order:

    1. Load a cached models/<target>.cbm if it exists.
    2. Otherwise train on augmented_<target>_{200k,20k}.csv when available.
    3. Otherwise fall back to a lighter fit on the real dataset's own labels.
    Cases 2 and 3 save the fitted model back to models/<target>.cbm.

    Raises FileNotFoundError when neither a model, augmented data, nor
    usable labels for *target* exist.
    """
    model_path = MODEL_DIR / f"{target}.cbm"
    if model_path.exists():
        model = CatBoostRegressor()
        model.load_model(str(model_path))
        return model

    aug_path = get_augmented_path(target)
    if aug_path is None:
        # Fallback: light train on real data if augmented not uploaded
        if target not in df_raw.columns:
            raise FileNotFoundError(
                f"No model '{model_path.name}' and no column '{target}' in ai_al.csv."
            )
        y = df_raw[target].apply(extract_first_float).astype(float)
        if y.dropna().empty:
            raise FileNotFoundError(f"No model and no usable labels for target '{target}'.")
        # Encoded (pre-imputation) real features fitted at module load.
        X = X_for_imputer
        model = CatBoostRegressor(
            iterations=400, depth=8, learning_rate=0.06,
            loss_function="RMSE", random_seed=42, verbose=False
        )
        model.fit(X, y)
        model.save_model(str(model_path))
        return model

    df_aug = pd.read_csv(aug_path)
    # NOTE(review): trains on every non-target column of the augmented CSV,
    # while prediction feeds exactly FEATURES — confirm the augmented files
    # contain only those columns, otherwise train/predict schemas diverge.
    X_aug = df_aug.drop(columns=[target])
    y_aug = df_aug[target].astype(float)
    model = CatBoostRegressor(
        iterations=700, depth=8, learning_rate=0.06,
        loss_function="RMSE", random_seed=42, verbose=False
    )
    model.fit(X_aug, y_aug)
    model.save_model(str(model_path))
    return model
319
+
320
# Lazy per-target model cache: each model is loaded/trained on first request.
_models = {}

def get_model(target: str):
    """Return the CatBoost model for *target*, creating and caching it on first use."""
    try:
        return _models[target]
    except KeyError:
        model = load_or_train_catboost(target)
        _models[target] = model
        return model
325
+
326
# -----------------------------
# Optional DOI database
# -----------------------------
# Builds the literature-matching machinery from doi.csv. Any failure disables
# DOI lookup gracefully: _closest_doi is replaced by a stub returning the error.
try:
    if DOI_PATH.exists():
        df_doi_raw = pd.read_csv(DOI_PATH)
    else:
        raise FileNotFoundError("doi.csv not found")

    # Same header canonicalization and numeric coercion as the main dataset.
    df_doi_raw.columns = (
        df_doi_raw.columns.str.strip()
        .str.lower()
        .str.replace("[^0-9a-zA-Z]+", "_", regex=True)
    )
    for c in ["species", "media"]:
        if c in df_doi_raw.columns: df_doi_raw[c] = df_doi_raw[c].map(normalize_str)
    if "light" in df_doi_raw.columns: df_doi_raw["light"] = coerce_numeric(df_doi_raw["light"], "cycle_first")
    for c in ["expo_day","expo_night","_c","ph","days"]:
        if c in df_doi_raw.columns: df_doi_raw[c] = coerce_numeric(df_doi_raw[c], "float")

    # First column that looks like a DOI/reference link, if any.
    doi_col_candidates = [c for c in df_doi_raw.columns if c.lower() in {"doi","doi_id","reference","url","link"}]
    DOI_COL = doi_col_candidates[0] if doi_col_candidates else None

    # Per-column scale (robust 5-95% span) used to normalize distances.
    NUMERIC_COLS = ["light","expo_day","expo_night","_c","ph","days"]
    scales = {}
    for col in NUMERIC_COLS:
        if col not in df_doi_raw.columns: continue
        v = pd.to_numeric(df_doi_raw[col], errors="coerce").dropna()
        if len(v) >= 4:
            lo, hi = np.percentile(v, [5,95]); span = max(1e-6, hi - lo)
        elif len(v) > 1:
            span = max(1e-6, v.max() - v.min())
        else:
            span = 1.0
        scales[col] = span

    def _media_similarity(a, b):
        """Similarity in [0, 1] between two media names after alias canonicalization."""
        a = normalize_str(a); b = normalize_str(b)
        def canon(m):
            if m in MEDIA_ALIASES: return m
            for k, syns in MEDIA_ALIASES.items():
                ns = [normalize_str(s) for s in syns]
                if m == k or m in ns: return k
            return m
        from difflib import SequenceMatcher
        ca, cb = canon(a), canon(b)
        return 1.0 if ca == cb else SequenceMatcher(None, ca, cb).ratio()

    def _doi_url(x):
        """Turn a DOI cell into a clickable URL; None for missing values."""
        if x is None or (isinstance(x, float) and np.isnan(x)): return None
        s = str(x).strip()
        if s.startswith("http://") or s.startswith("https://"): return s
        s = s.lower().replace("doi:", "").strip()
        return f"https://doi.org/{s}"

    def _closest_doi(species, media, light, expo_day, expo_night, temp_c, ph, days, topk=3):
        """Markdown list of the *topk* doi.csv rows closest to the query conditions.

        Score = media dissimilarity penalty + mean scale-normalized absolute
        distance over the numeric columns present in both query and row.
        """
        if df_doi_raw.empty: return "> ℹ️ doi.csv is empty."
        s_key = _normalize_species_label(normalize_str(species))
        df_cand = df_doi_raw[df_doi_raw["species"] == s_key]
        if df_cand.empty:
            # No exact species hit: fall back to the closest species name.
            sp_unique = df_doi_raw["species"].dropna().unique().tolist()
            best = get_close_matches(s_key, sp_unique, n=1, cutoff=0.6)
            df_cand = df_doi_raw[df_doi_raw["species"] == (best[0] if best else s_key)]

        q = {
            "light": parse_cycle_first(light),
            "expo_day": extract_first_float(expo_day),
            "expo_night": extract_first_float(expo_night),
            "_c": extract_first_float(temp_c),
            "ph": extract_first_float(ph),
            "days": extract_first_float(days),
        }

        rows = []
        for _, r in df_cand.iterrows():
            sim = _media_similarity(media, r.get("media", "")); media_penalty = (1.0 - sim) * 0.5
            dist = 0.0; denom = 0
            for col in ["light","expo_day","expo_night","_c","ph","days"]:
                if col in df_cand.columns:
                    rv, qv = r.get(col, np.nan), q[col]
                    if pd.notna(rv) and pd.notna(qv):
                        dist += abs(float(qv) - float(rv)) / scales.get(col, 1.0); denom += 1
            dist = dist/denom if denom>0 else 1.0
            rows.append((media_penalty + dist, r))
        if not rows: return "> ℹ️ No comparable rows in doi.csv."
        rows.sort(key=lambda x: x[0]); top = rows[:topk]

        md = "### 📚 Closest DOI matches\n"
        for rank, (score, r) in enumerate(top, 1):
            # exp(-score) maps [0, inf) scores onto a (0, 100]% similarity figure.
            sim_pct = max(0.0, min(100.0, 100.0 * np.exp(-score)))
            doi_link = _doi_url(r.get(DOI_COL)) if DOI_COL else None
            head = f"**{rank}. {r.get('species','?')} — {r.get('media','?')}** · Similarity **{sim_pct:.1f}%**"
            if doi_link: head += f" · [DOI]({doi_link})"
            md += head + "\n"
            md += (f"• Light: {r.get('light','NA')} · Day: {r.get('expo_day','NA')} · Night: {r.get('expo_night','NA')} · "
                   f"T(°C): {r.get('_c','NA')} · pH: {r.get('ph','NA')} · Days: {r.get('days','NA')}\n")
        return md

    DOI_READY = True
except Exception as e:
    DOI_READY = False
    # BUG FIX: Python 3 deletes the `except ... as e` binding when the clause
    # ends, so a closure referencing `e` raised NameError when finally called.
    # Capture the message into a module-level variable first.
    _DOI_ERROR = str(e)
    def _closest_doi(*args, **kwargs): return f"> ⚠️ DOI lookup unavailable: {_DOI_ERROR}"
428
+
429
# -----------------------------
# Preprocess + validate pair
# -----------------------------
def preprocess_row(species, media, light, expo_day, expo_night, temp_c, ph, days):
    """Validate the species/media pair and build one encoded, imputed feature row."""
    species_n = normalize_str(species)
    media_n = normalize_str(media)
    # Reject combinations outside the curated pair list...
    if species_n not in ALLOWED_PAIRS:
        raise ValueError(f"Species '{species}' not allowed.")
    if media_n not in ALLOWED_PAIRS[species_n]:
        raise ValueError(f"Media '{media}' not allowed for species '{species}'.")
    # ...and values the label encoders never saw during fitting.
    if species_n not in value_lists["species"]:
        raise ValueError(f"Species '{species}' not present in training encodings.")
    if media_n not in value_lists["media"]:
        raise ValueError(f"Media '{media}' not present in training encodings.")

    raw = {
        "species": species_n, "media": media_n, "light": light,
        "expo_day": expo_day, "expo_night": expo_night,
        "_c": temp_c, "ph": ph, "days": days,
    }
    row = pd.DataFrame([raw], columns=FEATURES)

    # Label-encode categoricals with the training-time encoders.
    for col in CATEGORICAL:
        row[col] = encoders[col].transform([row.loc[0, col]])[0]

    # Parse numerics ('A:B' light cycles keep A), then median-impute any gaps.
    row["light"] = row["light"].apply(parse_cycle_first)
    for col in ["expo_day","expo_night","_c","ph","days"]:
        row[col] = row[col].apply(extract_first_float)

    return pd.DataFrame(imputer.transform(row[FEATURES]), columns=FEATURES)
458
+
459
# -----------------------------
# Uncertainty engine (KNN on augmented)
# -----------------------------
# Per-target caches, populated lazily by _load_aug_and_knn.
_AUG = {}   # target -> (X_aug_np, y_aug_np)
_KNN = {}   # target -> fitted NearestNeighbors
_PERC = {}  # target -> {feature: (p05, p95)} sweep ranges for the response plot
K_NEI = 200  # neighbours pooled for the local quantile band
Q_LO, Q_HI = 0.10, 0.90  # band edges -> an 80% empirical interval
467
+
468
def _load_aug_and_knn(target: str):
    """Lazily load the augmented dataset for *target* and fit its KNN index.

    Populates _AUG (feature matrix + labels), _KNN (NearestNeighbors), and
    _PERC (5th/95th feature percentiles used as plot sweep ranges). No-op
    when the target is already cached. Raises FileNotFoundError when no
    augmented CSV exists for the target.
    """
    if target in _KNN: return
    aug_path = get_augmented_path(target)
    if aug_path is None:
        raise FileNotFoundError(
            f"Missing augmented file for '{target}'. Upload 'augmented_{target}_200k.csv' or '_20k.csv'."
        )
    df_aug = pd.read_csv(aug_path)
    # Ensure the full feature schema exists, then median-impute so the KNN
    # operates in the same space as model queries.
    for c in FEATURES:
        if c not in df_aug.columns: df_aug[c] = np.nan
    X_aug = df_aug[FEATURES].copy()
    X_aug_imp = pd.DataFrame(imputer.transform(X_aug), columns=FEATURES)
    y_aug = df_aug[target].astype(float).values
    X_np = X_aug_imp.values.astype(float)

    # Robust 5-95% range per feature; drives the response plot's x-axis span.
    perc = {}
    for j, c in enumerate(FEATURES):
        colv = X_np[:, j]
        perc[c] = (np.nanpercentile(colv, 5), np.nanpercentile(colv, 95))

    nn = NearestNeighbors(n_neighbors=min(K_NEI, len(X_np)), algorithm="auto")
    nn.fit(X_np)
    _AUG[target] = (X_np, y_aug)
    _KNN[target] = nn
    _PERC[target] = perc
493
+
494
def _local_interval(target: str, X_query: np.ndarray):
    """Empirical (Q_LO, Q_HI) band of the K nearest augmented neighbours, per query row."""
    _load_aug_and_knn(target)
    X_aug, y_aug = _AUG[target]
    neighbours = _KNN[target]
    k = min(K_NEI, len(X_aug))
    _, idx = neighbours.kneighbors(X_query, n_neighbors=k, return_distance=True)
    neighbour_targets = y_aug[idx]
    lower = np.quantile(neighbour_targets, Q_LO, axis=1)
    upper = np.quantile(neighbour_targets, Q_HI, axis=1)
    return lower, upper
503
+
504
# -----------------------------
# Predict + Uncertainty + Plot
# -----------------------------
def predict_and_plot_ui(target, species, media, light, expo_day, expo_night, temp_c, ph, days, plot_var):
    """Gradio callback: point prediction, local KNN interval, and a 1-D response sweep.

    Returns (markdown summary, matplotlib figure). On any failure the markdown
    carries the error text and the figure is blank.
    """
    try:
        # 1) preprocess single row
        X_one = preprocess_row(species, media, light, expo_day, expo_night, temp_c, ph, days)
        model = get_model(target)
        yhat = float(model.predict(X_one)[0])

        # 2) local uncertainty at current point
        qlo, qhi = _local_interval(target, X_one.values)
        lo, hi = float(qlo[0]), float(qhi[0])

        # 3) response curve vs selected variable (unknown names fall back to "light")
        plot_var = (plot_var or "light").strip().lower()
        if plot_var not in FEATURES:
            plot_var = "light"
        j = FEATURES.index(plot_var)
        _load_aug_and_knn(target)
        # Sweep between the augmented data's 5th and 95th percentile for that feature.
        p05, p95 = _PERC[target][plot_var]
        xs = np.linspace(p05, p95, 60)

        # Hold every other feature at the query's value while sweeping column j.
        X_grid = np.repeat(X_one.values, len(xs), axis=0)
        X_grid[:, j] = xs
        y_grid = model.predict(pd.DataFrame(X_grid, columns=FEATURES)).astype(float)
        qlo_g, qhi_g = _local_interval(target, X_grid)

        # 4) combined plot: mean curve, quantile band, and the query point itself
        fig, ax = plt.subplots(figsize=(7.0, 4.2))
        ax.plot(xs, y_grid, label="Predicted mean")
        ax.fill_between(xs, qlo_g, qhi_g, alpha=0.25, label=f"Local {int((Q_HI-Q_LO)*100)}% band")
        x0 = float(X_one.values[0, j])
        ax.axvline(x0, linestyle="--", alpha=0.6)
        ax.scatter([x0], [yhat], zorder=3)
        ax.set_xlabel(plot_var)
        ax.set_ylabel(target)
        ax.set_title(f"{target} vs {plot_var} (others fixed)")
        ax.legend(loc="best")
        plt.tight_layout()

        md = (
            f"### Prediction\n"
            f"**{target}** = **{yhat:.3f}** \n"
            f"Local {int((Q_HI-Q_LO)*100)}% interval: **[{lo:.3f}, {hi:.3f}]**"
        )
        return md, fig
    except Exception as e:
        # Blank placeholder figure so the Plot component always gets an output.
        fig, ax = plt.subplots(figsize=(6,3)); ax.axis("off"); plt.tight_layout()
        return f"Error: {e}", fig
554
+
555
def doi_matches_ui(target, species, media, light, expo_day, expo_night, temp_c, ph, days):
    """Gradio callback: top-3 literature matches (*target* is accepted but unused)."""
    return _closest_doi(
        species, media, light, expo_day, expo_night, temp_c, ph, days, topk=3
    )
557
+
558
def update_suggestion(species, target):
    """Suggestion callback variant (args reversed vs. update_suggestion_panel)."""
    return _format_suggestion_md(species=species, target=target)
560
+
561
# -----------------------------
# UI — layout
# -----------------------------
# Shared Gradio theme plus minimal "card" styling used across the Blocks layout.
theme = Soft(primary_hue="emerald", neutral_hue="slate", radius_size="lg", spacing_size="sm")
CSS = """
.card { border: 1px solid var(--border-color-primary); border-radius: 12px; padding: 14px; background: var(--block-background-fill); }
.small { font-size: 0.92rem; opacity: 0.95; }
"""
569
+
570
def update_media(species):
    """Gradio callback: restrict the media dropdown to the selected species."""
    if not species:
        return gr.update(choices=[], value=None)
    options = allowed_media_for(normalize_str(species))
    default = options[0] if options else None
    return gr.update(choices=options, value=default)
575
+
576
# Initial dropdown state: first curated species and its first allowed medium.
allowed_species = sorted(ALLOWED_PAIRS.keys())
first_species = allowed_species[0] if allowed_species else None
first_media_choices = allowed_media_for(first_species) if first_species else []
first_media = first_media_choices[0] if first_media_choices else None
580
+
581
# Two-column layout: inputs + curated suggestions on the left, prediction,
# response plot, and DOI matches on the right.
with gr.Blocks(title="Algae Yield Predictor", theme=theme, css=CSS) as demo:
    # Header banner; appends warnings when pair curation or DOI data is missing.
    gr.Markdown(
        f"<h1>Algae Yield Predictor</h1>"
        f"<div class='small'>Predict <b>biomass / lipid / protein / carbohydrate</b> and visualize local uncertainty."
        f"{' &nbsp;<em>(All species–media pairs enabled; CSV match not found.)</em>' if WARN_ALL else ''}"
        f"{'' if DOI_PATH.exists() and DOI_READY else ' &nbsp;<em>(DOI file missing or lacks a doi column.)</em>'}"
        f"</div>",
        elem_classes=["card"]
    )

    with gr.Row():
        with gr.Column(scale=6):
            with gr.Group(elem_classes=["card"]):
                gr.Markdown("### Inputs")
                target_dd = gr.Dropdown(choices=TARGETS, value="biomass", label="Target", info="Choose outcome to predict")
                with gr.Row():
                    species_dd = gr.Dropdown(choices=allowed_species, value=first_species, label="Species", info="Only curated species")
                    media_dd = gr.Dropdown(choices=first_media_choices, value=first_media, label="Medium", info="Restricted by species")
                gr.Markdown("#### Culture Conditions", elem_classes=["small"])
                with gr.Row():
                    light_sl = gr.Slider(10, 400, value=120, step=5, label="Light (μmol·m⁻²·s⁻¹)")
                    days_sl = gr.Slider(1, 45, value=14, step=1, label="Days", info="Total culture duration")
                with gr.Row():
                    day_sl = gr.Slider(0, 24, value=18, step=1, label="Day Exposure (h)")
                    night_sl = gr.Slider(0, 24, value=6, step=1, label="Night Exposure (h)")
                with gr.Row():
                    temp_num = gr.Number(value=25, label="Temperature (°C)", precision=1)
                    ph_num = gr.Number(value=7.0, label="pH", precision=2)
                with gr.Row():
                    # Only numeric features are offered for the sweep; categoricals
                    # (species/media) are excluded deliberately.
                    plot_var_dd = gr.Dropdown(
                        choices=["light","days","expo_day","expo_night","_c","ph"],
                        value="light",
                        label="Plot variable",
                        info="Sweep one input to see response curve with uncertainty band"
                    )
                with gr.Row():
                    go = gr.Button("Predict + Plot", variant="primary")
                    doi_btn = gr.Button("Find Closest DOI Matches", variant="secondary")

            with gr.Group(elem_classes=["card"]):
                gr.Markdown("### Suggested Conditions")
                suggest_md = gr.Markdown(value=_format_suggestion_md(first_species or "", "biomass"))

        with gr.Column(scale=6):
            with gr.Group(elem_classes=["card"]):
                pred_md = gr.Markdown("Click <b>Predict + Plot</b> to run.")
            with gr.Group(elem_classes=["card"]):
                gr.Markdown("### Combined Response Plot")
                plot_out = gr.Plot()
            with gr.Group(elem_classes=["card"]):
                gr.Markdown("### Literature (DOI) Matches")
                doi_md = gr.Markdown("Click <b>Find Closest DOI Matches</b> to see references.")

    # Wiring
    # Species change cascades into both the media choices and the suggestion panel.
    species_dd.change(fn=update_media, inputs=species_dd, outputs=media_dd)
    target_dd.change(update_suggestion_panel, inputs=[target_dd, species_dd], outputs=suggest_md)
    species_dd.change(update_suggestion_panel, inputs=[target_dd, species_dd], outputs=suggest_md)

    go.click(
        fn=predict_and_plot_ui,
        inputs=[target_dd, species_dd, media_dd, light_sl, day_sl, night_sl, temp_num, ph_num, days_sl, plot_var_dd],
        outputs=[pred_md, plot_out]
    )
    doi_btn.click(
        fn=doi_matches_ui,
        inputs=[target_dd, species_dd, media_dd, light_sl, day_sl, night_sl, temp_num, ph_num, days_sl],
        outputs=doi_md
    )
649
+
650
# Hugging Face Spaces executes app.py directly; the guard keeps the module
# importable for local tooling while still launching the UI when run as a script.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)