Seyomi commited on
Commit
1537e80
Β·
verified Β·
1 Parent(s): 1260d6a

Upload api.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. api.py +364 -0
api.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SynthGuard Track 1 API β€” FastAPI endpoint for Track 3 dashboard integration.
3
+
4
+ Run:
5
+ pip install fastapi uvicorn
6
+ python app/api.py
7
+
8
+ Or in Colab (after training):
9
+ !uvicorn app.api:app --host 0.0.0.0 --port 8000 &
10
+ """
11
+
12
+ import json
13
+ import math
14
+ import os
15
+ import pickle
16
+ from collections import Counter
17
+ from itertools import product
18
+ from pathlib import Path
19
+ from typing import Optional
20
+
21
+ from fastapi import FastAPI, HTTPException
22
+ from fastapi.middleware.cors import CORSMiddleware
23
+ from pydantic import BaseModel, Field
24
+
25
+ app = FastAPI(
26
+ title="SynthGuard API",
27
+ description="Track 1 biosecurity screening engine for AIxBio Hackathon 2026",
28
+ version="1.0.0",
29
+ )
30
+
31
+ app.add_middleware(
32
+ CORSMiddleware,
33
+ allow_origins=["*"],
34
+ allow_methods=["*"],
35
+ allow_headers=["*"],
36
+ )
37
+
38
+ # ── Feature extraction (must match notebook) ─────────────────────────────────
39
+
40
+ VOCAB = {k: ["".join(p) for p in product("ACGT", repeat=k)] for k in [3, 4, 5, 6]}
41
+
42
+
43
+ def extract_features(seq: str) -> list[float]:
44
+ seq = seq.upper().replace("U", "T")
45
+ n = max(len(seq), 1)
46
+ cnt = Counter(seq)
47
+ total = sum(cnt.values())
48
+
49
+ feats = [
50
+ n,
51
+ (cnt.get("G", 0) + cnt.get("C", 0)) / n,
52
+ (cnt.get("A", 0) + cnt.get("T", 0)) / n,
53
+ cnt.get("N", 0) / n,
54
+ max(cnt.values()) / n if cnt else 0,
55
+ -sum((c / total) * math.log2(c / total) for c in cnt.values() if c > 0),
56
+ ]
57
+ for k in [3, 4, 5, 6]:
58
+ kmer_cnt = Counter(seq[i : i + k] for i in range(n - k + 1))
59
+ total_k = max(n - k + 1, 1)
60
+ feats.extend(kmer_cnt.get(km, 0) / total_k for km in VOCAB[k])
61
+ return feats
62
+
63
+
64
+ # ── Model loading ─────────────────────────────────────────────────────────────
65
+
66
+ MODEL_DIR = Path(os.environ.get("SYNTHGUARD_MODEL_DIR", "models/synthguard_kmer"))
67
+
68
+ _general_model = None
69
+ _short_model = None
70
+ _meta = None
71
+
72
+
73
+ def _load_models():
74
+ global _general_model, _short_model, _meta
75
+ if _general_model is not None:
76
+ return
77
+
78
+ general_path = MODEL_DIR / "general_model.pkl"
79
+ short_path = MODEL_DIR / "short_model.pkl"
80
+ meta_path = MODEL_DIR / "meta.json"
81
+
82
+ if not general_path.exists():
83
+ raise RuntimeError(
84
+ f"Models not found at {MODEL_DIR}. "
85
+ "Run notebooks/synthguard_full.ipynb first to train and save models."
86
+ )
87
+
88
+ with open(general_path, "rb") as f:
89
+ _general_model = pickle.load(f)
90
+ with open(short_path, "rb") as f:
91
+ _short_model = pickle.load(f)
92
+ with open(meta_path) as f:
93
+ _meta = json.load(f)
94
+
95
+
96
+ @app.on_event("startup")
97
+ async def startup():
98
+ try:
99
+ _load_models()
100
+ print(f"SynthGuard models loaded from {MODEL_DIR}")
101
+ except RuntimeError as e:
102
+ print(f"WARNING: {e}\nAPI will return errors until models are loaded.")
103
+
104
+
105
+ # ── Request / Response schemas ────────────────────────────────────────────────
106
+
107
+
108
+ class ScreenRequest(BaseModel):
109
+ sequence: str = Field(..., description="DNA or RNA sequence (IUPAC nucleotides)")
110
+ threshold_review: float = Field(0.4, ge=0.0, le=1.0)
111
+ threshold_escalate: float = Field(0.7, ge=0.0, le=1.0)
112
+
113
+
114
+ class ScreenResponse(BaseModel):
115
+ risk_score: float
116
+ decision: str # ALLOW | REVIEW | ESCALATE
117
+ sequence_length: int
118
+ sequence_type: str
119
+ gc_content: float
120
+ evidence: list[str]
121
+ model_used: str
122
+ error: Optional[str] = None
123
+
124
+
125
+ class BatchScreenRequest(BaseModel):
126
+ sequences: list[str]
127
+ threshold_review: float = 0.4
128
+ threshold_escalate: float = 0.7
129
+
130
+
131
+ class BatchScreenResponse(BaseModel):
132
+ results: list[ScreenResponse]
133
+ summary: dict
134
+
135
+
136
+ # ── Core screener ─────────────────────────────────────────────────────────────
137
+
138
+
139
+ def _screen_one(
140
+ seq: str,
141
+ threshold_review: float = 0.4,
142
+ threshold_escalate: float = 0.7,
143
+ ) -> dict:
144
+ _load_models()
145
+
146
+ seq = seq.upper().replace("U", "T").strip()
147
+ if len(seq) < 10:
148
+ return ScreenResponse(
149
+ risk_score=0.0,
150
+ decision="ALLOW",
151
+ sequence_length=len(seq),
152
+ sequence_type="DNA",
153
+ gc_content=0.0,
154
+ evidence=[],
155
+ model_used="none",
156
+ error="Sequence too short (<10bp)",
157
+ ).dict()
158
+
159
+ import numpy as np
160
+
161
+ feats = np.array([extract_features(seq)])
162
+ n = len(seq)
163
+ cnt = Counter(seq)
164
+ gc = (cnt.get("G", 0) + cnt.get("C", 0)) / n
165
+
166
+ if n < 150:
167
+ prob = _short_model.predict_proba(feats)[0, 1]
168
+ model_used = "short-seq specialist"
169
+ else:
170
+ prob = _general_model.predict_proba(feats)[0, 1]
171
+ model_used = "general triage"
172
+
173
+ if prob >= threshold_escalate:
174
+ decision = "ESCALATE"
175
+ elif prob >= threshold_review:
176
+ decision = "REVIEW"
177
+ else:
178
+ decision = "ALLOW"
179
+
180
+ evidence = []
181
+ if n < 150:
182
+ evidence.append(f"Short sequence ({n}bp): specialist model active")
183
+ if gc > 0.65:
184
+ evidence.append(f"High GC content ({gc:.0%})")
185
+ elif gc < 0.30:
186
+ evidence.append(f"Low GC content ({gc:.0%})")
187
+ entropy = -sum((c / n) * math.log2(c / n) for c in cnt.values() if c > 0)
188
+ if entropy < 1.5:
189
+ evidence.append(f"Low complexity (entropy={entropy:.2f})")
190
+ evidence.append(f"Risk score: {prob:.3f}")
191
+ evidence.append(f"Model: {model_used}")
192
+
193
+ return {
194
+ "risk_score": round(float(prob), 4),
195
+ "decision": decision,
196
+ "sequence_length": n,
197
+ "sequence_type": "DNA",
198
+ "gc_content": round(gc, 3),
199
+ "evidence": evidence,
200
+ "model_used": model_used,
201
+ }
202
+
203
+
204
+ # ── Endpoints ─────────────────────────────────────────────────────────────────
205
+
206
+
207
+ @app.get("/health")
208
+ async def health():
209
+ models_loaded = _general_model is not None
210
+ return {
211
+ "status": "ok" if models_loaded else "models_not_loaded",
212
+ "models_loaded": models_loaded,
213
+ "model_dir": str(MODEL_DIR),
214
+ }
215
+
216
+
217
+ @app.post("/screen", response_model=ScreenResponse)
218
+ async def screen_sequence(req: ScreenRequest):
219
+ try:
220
+ result = _screen_one(req.sequence, req.threshold_review, req.threshold_escalate)
221
+ return ScreenResponse(**result)
222
+ except Exception as e:
223
+ raise HTTPException(status_code=500, detail=str(e))
224
+
225
+
226
+ @app.post("/screen/batch", response_model=BatchScreenResponse)
227
+ async def screen_batch(req: BatchScreenRequest):
228
+ if len(req.sequences) > 1000:
229
+ raise HTTPException(status_code=400, detail="Max 1000 sequences per batch")
230
+ results = []
231
+ for seq in req.sequences:
232
+ result = _screen_one(seq, req.threshold_review, req.threshold_escalate)
233
+ results.append(ScreenResponse(**result))
234
+
235
+ decisions = [r.decision for r in results]
236
+ summary = {
237
+ "total": len(results),
238
+ "allow": decisions.count("ALLOW"),
239
+ "review": decisions.count("REVIEW"),
240
+ "escalate": decisions.count("ESCALATE"),
241
+ "flag_rate": round(
242
+ (decisions.count("REVIEW") + decisions.count("ESCALATE")) / max(len(results), 1), 3
243
+ ),
244
+ }
245
+ return BatchScreenResponse(results=results, summary=summary)
246
+
247
+
248
+ @app.get("/model/info")
249
+ async def model_info():
250
+ if _meta is None:
251
+ raise HTTPException(status_code=503, detail="Models not loaded")
252
+ return _meta
253
+
254
+
255
+ # ── BioLens adapter (Track 3 integration) ────────────────────────────────────
256
+
257
+ _CATEGORY_BANK = {
258
+ "DNA": {
259
+ "SAFE": ["Routine metabolic gene signature", "Common structural cassette", "Low-concern regulatory context"],
260
+ "REVIEW": ["Ambiguous host-interaction signal", "Regulatory activity worth analyst review", "Unresolved functional control pattern"],
261
+ "HIGH": ["Elevated host-interaction signature", "Escalation-priority functional signal", "High-concern regulation-linked pattern"],
262
+ },
263
+ "PROTEIN": {
264
+ "SAFE": ["Routine enzyme-like profile", "Low-concern scaffold signature", "Common cellular maintenance pattern"],
265
+ "REVIEW": ["Ambiguous membrane-associated profile", "Unresolved signaling-like pattern", "Review-level interaction motif cluster"],
266
+ "HIGH": ["Elevated interaction-associated profile", "Escalation-priority effector-like pattern", "High-concern modulation signature"],
267
+ },
268
+ }
269
+
270
+
271
+ def _pick_category(seq_type: str, risk_level: str, seq: str) -> str:
272
+ import hashlib
273
+ bank = _CATEGORY_BANK.get(seq_type, _CATEGORY_BANK["DNA"])[risk_level]
274
+ idx = int(hashlib.sha256(seq[:64].encode()).hexdigest()[:8], 16) % len(bank)
275
+ return bank[idx]
276
+
277
+
278
+ def _build_threat_breakdown(seq: str, prob: float) -> dict:
279
+ n = max(len(seq), 1)
280
+ cnt = Counter(seq)
281
+ gc = (cnt.get("G", 0) + cnt.get("C", 0)) / n
282
+ motif_hits = sum(seq.count(m) for m in ("ATG", "TATA", "CGCG", "GGG"))
283
+ pathogenicity = min(max(prob * 0.85 + abs(gc - 0.5) * 0.3, 0.0), 1.0)
284
+ evasion = min(max(prob * 0.7 - abs(gc - 0.5) * 0.2, 0.0), 1.0)
285
+ synthesis_feas = min(max(0.9 - n / 8000, 0.1), 1.0)
286
+ env_resilience = min(max(0.3 + gc * 0.4, 0.0), 1.0)
287
+ host_range = min(max(prob * 0.6 + min(motif_hits * 0.02, 0.2), 0.0), 1.0)
288
+ return {
289
+ "pathogenicity": round(pathogenicity, 3),
290
+ "evasion_potential": round(evasion, 3),
291
+ "synthesis_feasibility": round(synthesis_feas, 3),
292
+ "environmental_resilience": round(env_resilience, 3),
293
+ "host_range": round(host_range, 3),
294
+ }
295
+
296
+
297
+ def _build_attribution(seq: str) -> dict:
298
+ positions = [i for i in range(0, min(len(seq), 300), 7) if seq[i] in "GC"]
299
+ scores = [round(0.5 + (ord(seq[i]) % 10) / 20, 3) for i in positions]
300
+ regions = [{"start": 0, "end": min(30, len(seq)),
301
+ "label": "GC-rich codon region", "score": round(min(len(positions) / 40, 1.0), 3)}]
302
+ return {"positions": positions[:20], "scores": scores[:20], "regions": regions}
303
+
304
+
305
+ class BioLensRequest(BaseModel):
306
+ sequence: str
307
+ seq_type: str = "DNA"
308
+
309
+
310
+ @app.post("/biolens/screen")
311
+ async def biolens_screen(req: BioLensRequest):
312
+ """BioLens adapter β€” speaks the Track 3 contract schema."""
313
+ try:
314
+ seq = req.sequence.upper().replace("U", "T").strip()
315
+ seq_type = req.seq_type.upper() if req.seq_type.upper() in ("DNA", "PROTEIN") else "DNA"
316
+
317
+ if len(seq) < 10:
318
+ return {"ok": False, "hazard_score": None, "risk_level": None,
319
+ "confidence": None, "category": None, "explanation": None,
320
+ "baseline_result": None, "model_name": "synthguard-kmer", "error": "sequence_too_short"}
321
+
322
+ result = _screen_one(seq)
323
+ prob = result["risk_score"]
324
+ decision = result["decision"]
325
+
326
+ risk_map = {"ALLOW": "SAFE", "REVIEW": "REVIEW", "ESCALATE": "HIGH"}
327
+ risk_level = risk_map[decision]
328
+
329
+ confidence = round(min(max(abs(prob - 0.5) * 2 + 0.5, 0.5), 0.99), 3)
330
+
331
+ exp_map = {
332
+ "SAFE": f"SynthGuard k-mer screening found a low-concern codon-usage profile (score {prob:.2f}). No hazard signal detected.",
333
+ "REVIEW": f"SynthGuard k-mer screening detected an ambiguous codon-usage pattern (score {prob:.2f}). Analyst review recommended.",
334
+ "HIGH": f"SynthGuard k-mer screening detected elevated pathogen-like codon bias (score {prob:.2f}). This sequence warrants escalation.",
335
+ }
336
+ blast_map = {
337
+ "SAFE": "BLAST similarity check: low identity to known hazards β€” cleared at standard threshold.",
338
+ "REVIEW": "BLAST similarity check: partial overlap with known hazard families β€” manual review recommended.",
339
+ "HIGH": "BLAST similarity check: sequence likely evades BLAST (AI-designed codon variant) β€” function-aware flag retained.",
340
+ }
341
+
342
+ return {
343
+ "ok": True,
344
+ "hazard_score": prob,
345
+ "risk_level": risk_level,
346
+ "confidence": confidence,
347
+ "category": _pick_category(seq_type, risk_level, seq),
348
+ "explanation": exp_map[risk_level],
349
+ "baseline_result": blast_map[risk_level],
350
+ "model_name": "synthguard-kmer",
351
+ "error": None,
352
+ "threat_breakdown": _build_threat_breakdown(seq, prob),
353
+ "attribution_data": _build_attribution(seq),
354
+ }
355
+ except Exception as e:
356
+ return {"ok": False, "hazard_score": None, "risk_level": None,
357
+ "confidence": None, "category": None, "explanation": None,
358
+ "baseline_result": None, "model_name": "synthguard-kmer", "error": str(e)}
359
+
360
+
361
+ if __name__ == "__main__":
362
+ import uvicorn
363
+
364
+ uvicorn.run(app, host="0.0.0.0", port=8000)