Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, UploadFile, File, HTTPException, Query, Depends | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from .settings import settings | |
| from .deps import index, face, get_hf_token, build_agent_with_token | |
| from .models.face import ( | |
| EnrollResp, IdentifyReq, IdentifyResp, IdentifyHit, | |
| IdentifyManyReq, IdentifyManyResp, FaceDet, | |
| ) | |
| from .models.query import QueryReq, QueryResp | |
| from .services.aggregator import aggregate_by_user | |
| from .services.face_service import imdecode | |
| import numpy as np, uuid, cv2, os, io, zipfile, glob, shutil | |
app = FastAPI(title="Realtime BI Assistant")

# Wide-open CORS policy: every origin, method and header is accepted and
# credentialed requests are enabled.
# NOTE(review): browsers refuse a literal "*" origin on credentialed requests;
# check how the installed Starlette version treats allow_origins=["*"] together
# with allow_credentials=True, and pin explicit origins if cookies must work
# cross-site (or if any-origin credentialed access is not intended).
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
def root():
    """Liveness probe: report that the backend process is up and serving."""
    payload = {"ok": True, "msg": "Backend alive"}
    return payload
def _decide_identity(agg, threshold: float, margin: float):
    """Turn per-user aggregated match scores into an identity decision.

    Args:
        agg: Sequence of ``(user, score)`` pairs. The first entry is treated as
            the best candidate and the second (if present) as the runner-up;
            presumably sorted by descending score by ``aggregate_by_user`` —
            TODO confirm.
        threshold: Minimum best score required to accept a match.
        margin: Minimum gap required between the best and runner-up scores.

    Returns:
        ``(decision, best_score, margin_val)``. ``decision`` is the best user
        only when its score reaches ``threshold``, the gap reaches ``margin``
        and the label is not literally ``"Unknown"``; otherwise ``"Unknown"``.
        With no candidates the result is ``("Unknown", 0.0, 0.0)``; with a
        single candidate the runner-up score is taken to be ``-1.0``.
    """
    if not agg:
        return "Unknown", 0.0, 0.0

    top_user, top_score = agg[0]
    runner_up = -1.0  # sentinel when there is no second candidate
    if len(agg) > 1:
        runner_up = agg[1][1]
    gap = top_score - runner_up

    confident = top_score >= threshold and gap >= margin and top_user != "Unknown"
    decision = top_user if confident else "Unknown"
    return decision, top_score, gap
def _safe_extract(zf: zipfile.ZipFile, dest: str):
    """Extract every member of *zf* under *dest*, skipping unsafe entries.

    A member is skipped when its resolved path would not land strictly inside
    *dest*. This covers ``../`` traversal, absolute names, and entries that
    resolve to *dest* itself.

    Args:
        zf: An open ZIP archive to read members from.
        dest: Target directory; created if it does not already exist.
    """
    os.makedirs(dest, exist_ok=True)
    # Resolve the destination once; every member path is checked against it.
    dest_root = os.path.realpath(dest) + os.sep
    for member in zf.infolist():
        target = os.path.realpath(os.path.join(dest, member.filename))
        if not target.startswith(dest_root):
            continue  # path escapes dest -> ignore silently
        if member.is_dir():
            os.makedirs(target, exist_ok=True)
            continue
        os.makedirs(os.path.dirname(target), exist_ok=True)
        # Stream in chunks rather than src.read(): a single highly compressed
        # member would otherwise be decompressed entirely into memory.
        with zf.open(member) as src, open(target, "wb") as out:
            shutil.copyfileobj(src, out)
| def _guess_images_root(tmpdir: str) -> str | None: | |
| pref = os.path.join(tmpdir, "Images") | |
| if os.path.isdir(pref): | |
| return pref | |
| for root, dirs, files in os.walk(tmpdir): | |
| subdirs = [os.path.join(root, d) for d in dirs] | |
| if subdirs and any( | |
| any(fn.lower().endswith((".jpg",".jpeg",".png")) for fn in os.listdir(sd)) | |
| for sd in subdirs | |
| ): | |
| return root | |
| return None | |
async def enroll_zip(zipfile_upload: UploadFile = File(...)):
    """Bulk-enroll faces from an uploaded ZIP archive.

    Expected layout: ``Images/<UserName>/*.jpg|jpeg|png``. The image root is
    located by ``_guess_images_root``, so a nested or differently named root
    may also be accepted. For every readable image, the best face embedding is
    L2-normalised and upserted into the local FAISS index with
    ``{"user", "det_score", "source": "enroll_zip"}`` metadata.

    Args:
        zipfile_upload: The uploaded ``.zip`` file.

    Returns:
        EnrollResp with the users that received at least one vector and the
        total number of vectors added.

    Raises:
        HTTPException(400): the name does not end in ``.zip``, the payload is
            not a valid ZIP archive, no image root is found, or the root has
            no user folders.
    """
    filename = zipfile_upload.filename or ""  # filename can be None
    if not filename.lower().endswith(".zip"):
        raise HTTPException(400, "Please upload a .zip")
    raw = await zipfile_upload.read()
    # NOTE(review): hard-coded scratch location; tempfile.mkdtemp() would be
    # more portable if /workspace is not guaranteed to be writable.
    tmpdir = os.path.join("/workspace", "upload", uuid.uuid4().hex[:8])
    os.makedirs(tmpdir, exist_ok=True)
    try:
        try:
            with zipfile.ZipFile(io.BytesIO(raw), "r") as zf:
                _safe_extract(zf, tmpdir)
        except zipfile.BadZipFile as exc:
            # A corrupt / non-ZIP payload is a client error, not a 500.
            raise HTTPException(400, "Uploaded file is not a valid ZIP archive") from exc

        root = _guess_images_root(tmpdir)
        if not root:
            raise HTTPException(400, "Couldn't find 'Images/<UserName>/*' structure in ZIP")

        # Folder names come from the archive: glob.escape keeps names such as
        # "Team [A]" from being parsed as glob patterns (which would make them
        # silently match nothing). glob's "*" still skips dotfiles.
        user_dirs = sorted(
            p for p in glob.glob(os.path.join(glob.escape(root), "*")) if os.path.isdir(p)
        )
        if not user_dirs:
            raise HTTPException(400, "No user folders found under Images/")

        image_suffixes = (".jpg", ".jpeg", ".png")
        total = 0
        enrolled_users = []
        # NOTE: decoding and embedding below are synchronous calls made inside
        # an async handler, so they block the event loop while they run.
        for udir in user_dirs:
            user = os.path.basename(udir)
            paths = sorted(
                p for p in glob.glob(os.path.join(glob.escape(udir), "*"))
                if p.lower().endswith(image_suffixes)
            )
            count_user = 0
            for p in paths:
                img = cv2.imdecode(np.fromfile(p, dtype=np.uint8), cv2.IMREAD_COLOR)
                if img is None:  # unreadable / not actually an image
                    continue
                _bbox, emb, det_score = face.embed_best(img)
                if emb is None:  # no face detected
                    continue
                vec = emb.astype(np.float32)
                vec = vec / (np.linalg.norm(vec) + 1e-9)  # L2-normalise; eps avoids /0
                index.add_vectors(
                    vecs=np.array([vec]),
                    metas=[{"user": user, "det_score": float(det_score), "source": "enroll_zip"}],
                    ids=[f"{user}::{uuid.uuid4().hex[:8]}"],
                )
                count_user += 1
            if count_user > 0:
                enrolled_users.append(user)
                total += count_user
        return EnrollResp(users=enrolled_users, total_vectors=total)
    finally:
        # ignore_errors=True already swallows any removal failure.
        shutil.rmtree(tmpdir, ignore_errors=True)
| # ---------- endpoints ---------- | |
async def upsert_image(user: str = Query(..., description="User label"),
                       image: UploadFile = File(...)):
    """Add one face embedding to the index under the label *user*.

    The best face found in *image* is embedded, L2-normalised and stored with
    ``{"user", "det_score", "source"}`` metadata under a new id of the form
    ``"<user>::<8 hex chars>"``.

    Returns:
        ``{"ok": True, "id", "user", "det_score"}`` on success, or
        ``{"ok": False, "msg": "no face detected"}`` when no face is found.

    Raises:
        HTTPException(400): blank user label, or the bytes are not a decodable image.
    """
    # A blank label would be stored verbatim and could later be returned by
    # _decide_identity as an identity decision of "".
    if not user.strip():
        raise HTTPException(400, "User label must not be empty")
    raw = await image.read()
    img = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
    if img is None:  # cv2.imdecode signals undecodable bytes with None
        raise HTTPException(400, "Invalid image file")
    _bbox, emb, det_score = face.embed_best(img)
    if emb is None:
        return {"ok": False, "msg": "no face detected"}
    vec = emb.astype(np.float32)
    vec = vec / (np.linalg.norm(vec) + 1e-9)  # L2-normalise; eps guards a zero vector
    vid = f"{user}::{uuid.uuid4().hex[:8]}"
    score = float(det_score)
    index.add_vectors(
        vecs=np.array([vec]),
        # "source" mirrors the tag enroll_zip writes, so every vector records
        # which endpoint created it.
        metas=[{"user": user, "det_score": score, "source": "upsert_image"}],
        ids=[vid],
    )
    return {"ok": True, "id": vid, "user": user, "det_score": score}
async def identify(req: IdentifyReq):
    """Identify the single best face in a base64-encoded image.

    The best face's embedding is queried against the index
    (``settings.TOPK_DB`` candidates). Candidates are aggregated per user, and
    ``_decide_identity`` applies ``settings.THRESHOLD`` / ``settings.MARGIN``.

    Returns:
        IdentifyResp. When no face is found, ``decision`` is ``"NoFace"`` with
        zero scores; otherwise it is the matched user or ``"Unknown"``.
        ``topk`` holds up to ``req.top_k`` per-user scores.

    Raises:
        HTTPException(400): ``req.image_b64`` could not be decoded.
    """
    try:
        img = imdecode(req.image_b64)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Bad image_b64") from exc
    # imdecode's failure mode is not visible here. cv2.imdecode reports an
    # undecodable buffer by returning None (see upsert_image), so treat None as
    # bad input instead of passing it on to face.embed_best.
    if img is None:
        raise HTTPException(status_code=400, detail="Bad image_b64")
    bbox, emb, _det_score = face.embed_best(img)
    if emb is None:
        return IdentifyResp(decision="NoFace", best_score=0.0, margin=0.0, topk=[], bbox=None)
    matches = index.query(emb, top_k=settings.TOPK_DB)
    agg = aggregate_by_user(matches)
    user, best, margin_val = _decide_identity(agg, settings.THRESHOLD, settings.MARGIN)
    topk = [IdentifyHit(user=u, score=s) for u, s in agg[:req.top_k]]
    return IdentifyResp(decision=user, best_score=best, margin=margin_val, topk=topk, bbox=bbox)
# ---------- multi-face endpoint ----------
async def identify_many(req: IdentifyManyReq):
    """Detect every face in a base64-encoded image and identify each one.

    Each face embedding is queried against the index with ``req.top_k_db``
    candidates, falling back to ``settings.TOPK_DB`` when that is falsy.
    Results are aggregated per user and decided with ``_decide_identity``.

    Returns:
        IdentifyManyResp with one FaceDet per detected face (empty if none).
        Each FaceDet carries its bbox, decision, best score, margin and up to
        ``req.top_k`` per-user scores.

    Raises:
        HTTPException(400): ``req.image_b64`` could not be decoded.
    """
    try:
        img = imdecode(req.image_b64)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Bad image_b64") from exc
    # Guard for a None decode result (cv2.imdecode's failure value) so bad
    # input is reported as 400 rather than failing inside face.embed_all.
    if img is None:
        raise HTTPException(status_code=400, detail="Bad image_b64")
    faces = face.embed_all(img)
    if not faces:
        return IdentifyManyResp(detections=[])

    # Per-request override of the candidate count; None/0 means "use default".
    top_k_db = req.top_k_db or settings.TOPK_DB
    detections: list[FaceDet] = []
    for bbox, emb, _det_score in faces:
        agg = aggregate_by_user(index.query(emb, top_k=top_k_db))
        user, best, margin_val = _decide_identity(agg, settings.THRESHOLD, settings.MARGIN)
        detections.append(FaceDet(
            bbox=bbox,
            decision=user,
            best_score=best,
            margin=margin_val,
            topk=[IdentifyHit(user=u, score=s) for u, s in agg[:req.top_k]],
        ))
    return IdentifyManyResp(detections=detections)
async def query(req: QueryReq, hf_token: str | None = Depends(get_hf_token)):
    """Answer a natural-language BI question with the SQL agent.

    Args:
        req: ``req.text`` is the question; ``req.user_id`` is passed to the agent.
        hf_token: Token resolved by ``get_hf_token`` (may be None) and used to
            build the agent.

    Returns:
        QueryResp with the agent's answer. ``citations`` holds ``"sql:<query>"``
        when the agent's metadata reports the SQL it ran, otherwise it is empty.

    Raises:
        HTTPException(400): the question is empty, or the agent call failed.
    """
    text = (req.text or "").strip()
    if not text:
        raise HTTPException(400, "Empty question")
    sql_agent = build_agent_with_token(hf_token)
    try:
        answer_text, meta = sql_agent.ask(req.user_id, text)
    except Exception as e:
        # Agent / LLM / SQL failures are reported to the client as a 400.
        raise HTTPException(status_code=400, detail=f"Query failed: {e}") from e

    # meta's schema is owned by the agent (not visible here). A missing 'sql'
    # entry (or a None meta) should not turn a successful answer into an error.
    try:
        sql = meta["sql"]
    except (KeyError, TypeError):
        sql = None
    citations = [f"sql:{sql}"] if sql else []
    return QueryResp(
        answer_text=answer_text,
        citations=citations,
        metrics={},
        chart_refs=[],
    )
| import requests | |
def llm_health(hf_token: str | None = Depends(get_hf_token)):
    """Ping the Hugging Face router with a one-token chat completion.

    The request's token is used when present, else ``settings.HF_TOKEN``. This
    never raises. It always returns a dict whose ``status`` field is either the
    upstream HTTP status or a synthetic code:
    400 (no token), 502 (content-decoding error) or 500 (any other exception).
    On a real response the dict also carries ``ce``, the upstream
    ``Content-Encoding`` header.
    """
    token = hf_token or settings.HF_TOKEN
    if not token:
        return {"status": 400, "ok": False, "body": "HF token missing"}
    try:
        request_headers = {
            "Authorization": f"Bearer {token}",
            "Accept": "application/json",
            # Request an uncompressed body, presumably to sidestep the gzip
            # decode failures handled below.
            "Accept-Encoding": "identity",
        }
        payload = {
            "model": settings.LLM_MODEL_ID,
            "messages": [{"role": "user", "content": [{"type": "text", "text": "ping"}]}],
            "max_tokens": 1,
            "stream": False,
        }
        resp = requests.post(
            "https://router.huggingface.co/v1/chat/completions",
            headers=request_headers,
            json=payload,
            timeout=20,
        )
        succeeded = resp.ok
        # Success is summarised as "ok"; failures report up to 200 chars of the
        # error body, falling back to the raw bytes' repr when text is empty.
        if succeeded:
            detail = "ok"
        elif resp.text:
            detail = resp.text[:200]
        else:
            detail = str(resp.content[:200])
        return {
            "status": resp.status_code,
            "ok": succeeded,
            "body": detail,
            "ce": resp.headers.get("content-encoding"),
        }
    except requests.exceptions.ContentDecodingError as e:
        # Body could not be decoded (e.g. a gzip mismatch) - report it distinctly.
        return {"status": 502, "ok": False, "body": f"gzip decode error: {e.__class__.__name__}"}
    except Exception as e:
        return {"status": 500, "ok": False, "body": f"{type(e).__name__}: {e}"}