"""
app/routes/benchmarks.py
Accuracy benchmarking system.
Doc says: "Your accuracy is unknowable. Without ground-truth labels
and systematic accuracy testing, you have no idea if your health scores
are correct. This is dangerous for a health product."
This module fixes that:
- Store ground-truth nutrition data for test products
- Run scanner against them
- Measure F1, field accuracy, score delta
- Publish results (builds trust + validates claims)
"""
import json
import logging
import os
import re

from fastapi import APIRouter, Request, HTTPException, Form, File, UploadFile
from fastapi.responses import JSONResponse

from app.models.db import db_conn
from app.services.image import validate_image, assess_image_quality, deblur_and_enhance, ocr_quality_score
from app.services.ocr import run_ocr
from app.services.llm import analyse_label
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/benchmarks", tags=["benchmarks"])
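
# Shape of the ground-truth payload stored per product (illustrative values,
# not from a real product):
#
#   {
#     "nutrients": {"calories": 250, "protein": 8, "carbs": 30, "fat": 10,
#                   "sodium": 400, "fiber": 3, "sugar": 12},
#     "score": 6,
#     "ingredients": "whole grain oats, sugar, salt",
#     "barcode": "0123456789012"
#   }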
def _compute_field_accuracy(llm_output: dict, ground_truth: dict) -> dict:
    """
    Compare LLM-extracted nutrient values against hand-verified ground truth.
    Returns per-field results plus the share of fields falling within a
    tolerance band (15% or 2 units, whichever is larger).
    """
    fields = ["calories", "protein", "carbs", "fat", "sodium", "fiber", "sugar"]
    results = {}
    within_tol_ct = 0  # fields whose value landed inside the tolerance band
# Flatten LLM nutrient_breakdown into a dict
llm_nutr = {}
for n in llm_output.get("nutrient_breakdown", []):
key = n.get("name", "").lower()
val = n.get("value", 0)
if isinstance(val, (int, float)):
llm_nutr[key] = val
gt_nutr = ground_truth.get("nutrients", {})
    # Alias substrings so common label spellings still match
    # (e.g. "total carbohydrates" -> carbs, "fibre" -> fiber)
    aliases = {"carbs": ("carb",), "fiber": ("fibre",)}
    for field in fields:
        # Try to find the field in the LLM output (flexible naming)
        llm_val = None
        candidates = (field,) + aliases.get(field, ())
        for k, v in llm_nutr.items():
            if any(c in k or k in c for c in candidates):
                llm_val = v
                break
gt_val = gt_nutr.get(field)
if gt_val is None or llm_val is None:
results[field] = {"status": "missing", "llm": llm_val, "truth": gt_val}
continue
# Tolerance: within 15% or 2 units (whichever is larger)
tolerance = max(abs(gt_val) * 0.15, 2)
correct = abs(llm_val - gt_val) <= tolerance
        if correct:
            within_tol_ct += 1
results[field] = {
"status" : "correct" if correct else "wrong",
"llm" : llm_val,
"truth" : gt_val,
"delta" : round(llm_val - gt_val, 2),
"pct_err": round(abs(llm_val - gt_val) / max(gt_val, 1) * 100, 1),
}
# Score accuracy
gt_score = ground_truth.get("score")
llm_score = llm_output.get("score")
if gt_score is not None and llm_score is not None:
score_delta = abs(llm_score - gt_score)
results["score"] = {
"status" : "correct" if score_delta <= 1 else "wrong",
"llm" : llm_score,
"truth" : gt_score,
"delta" : llm_score - gt_score,
}
    # Missing fields count against accuracy: a value the pipeline failed to
    # extract is still an error for the product read-out
    accuracy_pct = round(within_tol_ct / len(fields) * 100, 1)
return {"fields": results, "field_accuracy_pct": accuracy_pct}
def _word_f1(pred: str, truth: str) -> float:
    """Bag-of-words F1 between OCR output and the ground-truth ingredients text."""
    if not truth:
        return 0.0
    # Tokenise on non-alphanumerics so "sugar," and "sugar" count as the same word
    pw = set(re.findall(r"[a-z0-9]+", pred.lower()))
    tw = set(re.findall(r"[a-z0-9]+", truth.lower()))
    tp = len(pw & tw)
    pr = tp / len(pw) if pw else 0
    rc = tp / len(tw) if tw else 0
    return round(2 * pr * rc / (pr + rc), 3) if (pr + rc) else 0.0
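
# Worked example (illustrative): pred = "sugar salt water",
# truth = "water, sugar, citric acid" -> overlap {sugar, water},
# precision = 2/3, recall = 2/4, F1 = 2*(2/3)*(1/2) / (2/3 + 1/2) ≈ 0.571.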
@router.post("/submit-ground-truth")
async def submit_ground_truth(
request : Request,
product_name: str = Form(...),
admin_token : str = Form(...),
nutrients : str = Form(...), # JSON: {"calories":250,"protein":8,...}
score : int = Form(...), # hand-assigned Eatlytic score
ingredients : str = Form(""),
barcode : str = Form(""),
):
"""
Admin: register a product's hand-verified nutrition data as ground truth.
Run this for 100+ products to get meaningful accuracy benchmarks.
"""
    # ADMIN_TOKEN should be set in the environment; "changeme" is only a
    # development fallback
    if admin_token != os.environ.get("ADMIN_TOKEN", "changeme"):
        raise HTTPException(status_code=403, detail="Invalid admin token")
try:
gt = json.loads(nutrients)
except json.JSONDecodeError:
raise HTTPException(status_code=400, detail="nutrients must be valid JSON")
    ground_truth = {"nutrients": gt, "score": score,
                    "ingredients": ingredients, "barcode": barcode}
with db_conn() as conn:
conn.execute(
"""INSERT INTO benchmarks(product_name, ground_truth_json)
VALUES(?,?)""",
(product_name, json.dumps(ground_truth))
)
return JSONResponse({"registered": True, "product": product_name,
"message": "Ground truth saved. Run /benchmarks/run to test."})
@router.post("/run/{benchmark_id}")
async def run_benchmark(
request : Request,
benchmark_id : int,
image : UploadFile = File(...),
admin_token : str = Form(...),
):
"""
Run the scanner against a benchmark product and store accuracy metrics.
"""
    if admin_token != os.environ.get("ADMIN_TOKEN", "changeme"):
        raise HTTPException(status_code=403, detail="Invalid admin token")
with db_conn() as conn:
bm_row = conn.execute(
"SELECT * FROM benchmarks WHERE id=?", (benchmark_id,)
).fetchone()
if not bm_row:
raise HTTPException(status_code=404, detail="Benchmark not found")
ground_truth = json.loads(bm_row["ground_truth_json"])
content = await image.read()
content = validate_image(content)
quality = assess_image_quality(content)
    # Run through the full pipeline, caching OCR results so each image is OCR'd once
    working, ocr_result = content, None
    if quality["is_blurry"]:
        try:
            enhanced, _ = deblur_and_enhance(content, quality["blur_severity"])
            base_ocr = run_ocr(content, "en")
            enh_ocr = run_ocr(enhanced, "en")
            # Keep the enhanced image unless it degrades OCR quality by more than 15%
            if ocr_quality_score(enh_ocr) >= ocr_quality_score(base_ocr) * 0.85:
                working, ocr_result = enhanced, enh_ocr
            else:
                ocr_result = base_ocr
        except Exception:
            pass
    if ocr_result is None:
        ocr_result = run_ocr(working, "en")
extracted_text = ocr_result["text"]
    # Word-level F1 of the OCR text against the hand-entered ingredients list
    ocr_f1 = _word_f1(extracted_text, ground_truth.get("ingredients", ""))
blur_info = {"detected": quality["is_blurry"], "severity": quality["blur_severity"],
"score": quality["blur_score"], "deblurred": working != content}
    # Pin a fixed, generic analysis context so benchmark results are
    # comparable across runs and not skewed by user personalisation
    llm_output = await analyse_label(
        extracted_text, "General Adult", "adult", "general",
        "en", "", blur_info, "high"
    )
field_acc = _compute_field_accuracy(llm_output, ground_truth)
    with db_conn() as conn:
conn.execute(
"""UPDATE benchmarks
SET ocr_text=?, llm_output_json=?, f1_score=?,
score_delta=?, field_accuracy=?, tested_at=datetime('now'),
model_used='llama-3.3-70b'
WHERE id=?""",
(extracted_text,
json.dumps(llm_output),
ocr_f1,
llm_output.get("score", 0) - ground_truth.get("score", 0),
json.dumps(field_acc),
benchmark_id)
)
return JSONResponse({
"benchmark_id" : benchmark_id,
"product_name" : bm_row["product_name"],
"ocr_f1" : ocr_f1,
"score_predicted" : llm_output.get("score"),
"score_truth" : ground_truth.get("score"),
"score_delta" : llm_output.get("score", 0) - ground_truth.get("score", 0),
"field_accuracy_pct": field_acc["field_accuracy_pct"],
"fields" : field_acc["fields"],
})
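
# Sketch of a batch run over registered benchmarks (assumes the API is served
# locally and httpx is installed; the image paths are hypothetical):
#
#   import os, httpx
#   token = os.environ["ADMIN_TOKEN"]
#   for bm_id, path in [(1, "labels/oat_crunch.jpg"), (2, "labels/cola.jpg")]:
#       with open(path, "rb") as f:
#           r = httpx.post(f"http://localhost:8000/benchmarks/run/{bm_id}",
#                          data={"admin_token": token}, files={"image": f})
#       print(bm_id, r.json().get("field_accuracy_pct") if r.status_code == 200 else r.text)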
@router.get("/report")
async def accuracy_report(request: Request):
"""
Aggregate accuracy report across all benchmarks.
Publish this to build trust + validate claims.
"""
    with db_conn() as conn:
        # tested_at is only set once a benchmark has been run, so filtering on
        # it keeps zero-F1 runs in the report instead of silently hiding failures
        rows = conn.execute(
            """SELECT product_name, f1_score, score_delta, field_accuracy, tested_at
               FROM benchmarks WHERE tested_at IS NOT NULL ORDER BY tested_at DESC"""
        ).fetchall()
).fetchall()
if not rows:
return JSONResponse({
"message" : "No benchmarks run yet.",
"action" : "POST /benchmarks/submit-ground-truth to register products, then POST /benchmarks/run/{id}",
"products_tested": 0,
})
f1_scores = [r["f1_score"] for r in rows if r["f1_score"]]
score_deltas = [abs(r["score_delta"]) for r in rows if r["score_delta"] is not None]
field_accs = []
for r in rows:
try:
fa = json.loads(r["field_accuracy"] or "{}")
pct = fa.get("field_accuracy_pct")
if pct is not None:
field_accs.append(pct)
except Exception:
pass
return JSONResponse({
"products_tested" : len(rows),
"avg_ocr_f1" : round(sum(f1_scores) / len(f1_scores), 3) if f1_scores else 0,
"avg_score_delta" : round(sum(score_deltas) / len(score_deltas), 2) if score_deltas else 0,
"avg_field_accuracy": f"{round(sum(field_accs)/len(field_accs), 1)}%" if field_accs else "N/A",
"results" : [
{"product": r["product_name"], "ocr_f1": r["f1_score"],
"score_delta": r["score_delta"], "tested_at": r["tested_at"]}
for r in rows
],
})
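
# Minimal local sanity check for the pure scoring helpers (illustrative values;
# no DB, OCR, or LLM calls involved).
if __name__ == "__main__":
    demo_llm = {
        "nutrient_breakdown": [
            {"name": "Calories", "value": 240},
            {"name": "Protein", "value": 8},
            {"name": "Total Carbohydrates", "value": 31},
        ],
        "score": 6,
    }
    demo_truth = {
        "nutrients": {"calories": 250, "protein": 8, "carbs": 30},
        "score": 7,
        "ingredients": "whole grain oats, sugar, salt",
    }
    print(json.dumps(_compute_field_accuracy(demo_llm, demo_truth), indent=2))
    print(_word_f1("whole grain oats sugar salt", demo_truth["ingredients"]))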