"""
app/routes/benchmarks.py
Accuracy benchmarking system.
Doc says: "Your accuracy is unknowable. Without ground-truth labels
and systematic accuracy testing, you have no idea if your health scores
are correct. This is dangerous for a health product."
This module fixes that:
- Store ground-truth nutrition data for test products
- Run scanner against them
- Measure F1, field accuracy, score delta
- Publish results (builds trust + validates claims)
"""
import json
import logging
import asyncio
from fastapi import APIRouter, Request, HTTPException, Form, File, UploadFile
from fastapi.responses import JSONResponse
from app.models.db import db_conn
from app.services.image import validate_image, assess_image_quality, deblur_and_enhance, ocr_quality_score
from app.services.ocr import run_ocr, detect_label_presence
from app.services.llm import analyse_label, call_llm
# Module-level logger; handlers/levels are configured by the application.
logger = logging.getLogger(__name__)

# All endpoints below are mounted under /benchmarks and grouped in the docs.
router = APIRouter(
    prefix="/benchmarks",
    tags=["benchmarks"],
)
def _compute_field_accuracy(llm_output: dict, ground_truth: dict) -> dict:
"""
Compare LLM-extracted nutrient values against hand-verified ground truth.
Returns per-field accuracy within a tolerance band.
"""
fields = ["calories", "protein", "carbs", "fat", "sodium", "fiber", "sugar"]
results = {}
exact_ct = 0
# Flatten LLM nutrient_breakdown into a dict
llm_nutr = {}
for n in llm_output.get("nutrient_breakdown", []):
key = n.get("name", "").lower()
val = n.get("value", 0)
if isinstance(val, (int, float)):
llm_nutr[key] = val
gt_nutr = ground_truth.get("nutrients", {})
for field in fields:
# Try to find field in LLM output (flexible naming)
llm_val = None
for k, v in llm_nutr.items():
if field in k or k in field:
llm_val = v
break
gt_val = gt_nutr.get(field)
if gt_val is None or llm_val is None:
results[field] = {"status": "missing", "llm": llm_val, "truth": gt_val}
continue
# Tolerance: within 15% or 2 units (whichever is larger)
tolerance = max(abs(gt_val) * 0.15, 2)
correct = abs(llm_val - gt_val) <= tolerance
if correct:
exact_ct += 1
results[field] = {
"status" : "correct" if correct else "wrong",
"llm" : llm_val,
"truth" : gt_val,
"delta" : round(llm_val - gt_val, 2),
"pct_err": round(abs(llm_val - gt_val) / max(gt_val, 1) * 100, 1),
}
# Score accuracy
gt_score = ground_truth.get("score")
llm_score = llm_output.get("score")
if gt_score is not None and llm_score is not None:
score_delta = abs(llm_score - gt_score)
results["score"] = {
"status" : "correct" if score_delta <= 1 else "wrong",
"llm" : llm_score,
"truth" : gt_score,
"delta" : llm_score - gt_score,
}
accuracy_pct = round(exact_ct / len(fields) * 100, 1)
return {"fields": results, "field_accuracy_pct": accuracy_pct}
def _word_f1(pred: str, truth: str) -> float:
if not truth:
return 0.0
pw = set(pred.lower().split())
tw = set(truth.lower().split())
tp = len(pw & tw)
pr = tp / len(pw) if pw else 0
rc = tp / len(tw) if tw else 0
return round(2 * pr * rc / (pr + rc), 3) if (pr + rc) else 0.0
@router.post("/submit-ground-truth")
async def submit_ground_truth(
    request: Request,
    product_name: str = Form(...),
    admin_token: str = Form(...),
    nutrients: str = Form(...),  # JSON: {"calories":250,"protein":8,...}
    score: int = Form(...),  # hand-assigned Eatlytic score
    ingredients: str = Form(""),
    barcode: str = Form(""),
):
    """
    Admin: register a product's hand-verified nutrition data as ground truth.
    Run this for 100+ products to get meaningful accuracy benchmarks.

    Form fields:
        product_name: name stored on the benchmark row.
        admin_token: must match the ``ADMIN_TOKEN`` environment variable.
        nutrients: JSON object of nutrient amounts,
            e.g. ``{"calories": 250, "protein": 8}``.
        score: hand-assigned Eatlytic score.
        ingredients: reference ingredient text, used for the OCR word-F1.
        barcode: optional barcode, stored inside the ground-truth JSON.

    Returns:
        JSON with ``registered``, ``product``, ``benchmark_id`` (the id to
        pass to ``POST /benchmarks/run/{benchmark_id}``) and ``message``.

    Raises:
        HTTPException(403): ``ADMIN_TOKEN`` is unset or the token is wrong.
        HTTPException(400): ``nutrients`` is not a valid JSON object.
    """
    import os
    import secrets

    expected_token = os.environ.get("ADMIN_TOKEN", "")
    if not expected_token:
        # No built-in fallback secret: a well-known default would let anyone
        # write ground truth on a deployment that forgot to set a token.
        logger.error("ADMIN_TOKEN is not configured; rejecting admin request")
        raise HTTPException(status_code=403, detail="Invalid admin token")
    # Constant-time comparison avoids leaking the token through timing;
    # bytes are compared because compare_digest rejects non-ASCII str.
    if not secrets.compare_digest(admin_token.encode(), expected_token.encode()):
        raise HTTPException(status_code=403, detail="Invalid admin token")

    try:
        gt = json.loads(nutrients)
    except json.JSONDecodeError as err:
        raise HTTPException(status_code=400, detail="nutrients must be valid JSON") from err
    if not isinstance(gt, dict):
        # Scoring looks nutrients up by field name; a list or number would
        # only fail later, during a benchmark run.
        raise HTTPException(status_code=400, detail="nutrients must be a JSON object")

    ground_truth = {
        "nutrients": gt,
        "score": score,
        "ingredients": ingredients,
        "barcode": barcode,
    }
    with db_conn() as conn:
        cur = conn.execute(
            """INSERT INTO benchmarks(product_name, ground_truth_json)
               VALUES(?,?)""",
            (product_name, json.dumps(ground_truth))
        )
        # The caller needs the new row id to call /run/{benchmark_id}.
        benchmark_id = getattr(cur, "lastrowid", None)

    if benchmark_id is not None:
        message = f"Ground truth saved. Run /benchmarks/run/{benchmark_id} to test."
    else:
        message = "Ground truth saved. Run /benchmarks/run to test."
    return JSONResponse({"registered": True, "product": product_name,
                         "benchmark_id": benchmark_id,
                         "message": message})
@router.post("/run/{benchmark_id}")
async def run_benchmark(
    request: Request,
    benchmark_id: int,
    image: UploadFile = File(...),
    admin_token: str = Form(...),
):
    """
    Run the scanner against a benchmark product and store accuracy metrics.

    Admin only. The uploaded image is validated and quality-checked. When it
    is flagged blurry, a deblurred copy is used if its OCR quality is at least
    85% of the original's. The OCR text is then analysed by the LLM and
    compared against the stored ground truth. Results are written back to the
    benchmark row.

    Returns:
        JSON with the OCR word-F1, predicted and true scores, ``score_delta``
        (None when either score is unavailable) and per-field accuracy.

    Raises:
        HTTPException(403): ``ADMIN_TOKEN`` is unset or the token is wrong.
        HTTPException(404): no benchmark row has ``benchmark_id``.
    """
    import os
    import secrets

    expected_token = os.environ.get("ADMIN_TOKEN", "")
    if not expected_token:
        # No built-in fallback secret (see submit_ground_truth).
        logger.error("ADMIN_TOKEN is not configured; rejecting admin request")
        raise HTTPException(status_code=403, detail="Invalid admin token")
    if not secrets.compare_digest(admin_token.encode(), expected_token.encode()):
        raise HTTPException(status_code=403, detail="Invalid admin token")

    with db_conn() as conn:
        bm_row = conn.execute(
            "SELECT * FROM benchmarks WHERE id=?", (benchmark_id,)
        ).fetchone()
    if not bm_row:
        raise HTTPException(status_code=404, detail="Benchmark not found")
    ground_truth = json.loads(bm_row["ground_truth_json"])

    content = await image.read()
    content = validate_image(content)
    quality = assess_image_quality(content)

    # Run through full pipeline. Keep the OCR result of whichever image wins
    # the deblur comparison instead of running OCR on it a third time.
    working = content
    ocr_result = None
    if quality["is_blurry"]:
        try:
            enhanced, _ = deblur_and_enhance(content, quality["blur_severity"])
            enhanced_ocr = run_ocr(enhanced, "en")
            original_ocr = run_ocr(content, "en")
            if ocr_quality_score(enhanced_ocr) >= ocr_quality_score(original_ocr) * 0.85:
                working, ocr_result = enhanced, enhanced_ocr
            else:
                ocr_result = original_ocr
        except Exception:
            # Best effort: fall back to the original image, but leave a trace.
            logger.warning(
                "Deblur/OCR comparison failed for benchmark %s; using original image",
                benchmark_id, exc_info=True,
            )
    if ocr_result is None:
        ocr_result = run_ocr(working, "en")

    extracted_text = ocr_result["text"]
    ocr_f1 = _word_f1(extracted_text,
                      ground_truth.get("ingredients", ""))
    blur_info = {"detected": quality["is_blurry"], "severity": quality["blur_severity"],
                 "score": quality["blur_score"], "deblurred": working != content}

    llm_output = await analyse_label(
        extracted_text, "General Adult", "adult", "general",
        "en", "", blur_info, "high"
    )
    field_acc = _compute_field_accuracy(llm_output, ground_truth)

    # Use the score delta computed alongside the field accuracy. It exists
    # only when both scores are present; defaulting a missing prediction to
    # 0 would record a fake, large error and skew the report.
    score_delta = field_acc["fields"].get("score", {}).get("delta")

    with db_conn() as conn:
        conn.execute(
            """UPDATE benchmarks
               SET ocr_text=?, llm_output_json=?, f1_score=?,
                   score_delta=?, field_accuracy=?, tested_at=datetime('now'),
                   model_used='llama-3.3-70b'
               WHERE id=?""",
            (extracted_text,
             json.dumps(llm_output),
             ocr_f1,
             score_delta,
             json.dumps(field_acc),
             benchmark_id)
        )

    return JSONResponse({
        "benchmark_id": benchmark_id,
        "product_name": bm_row["product_name"],
        "ocr_f1": ocr_f1,
        "score_predicted": llm_output.get("score"),
        "score_truth": ground_truth.get("score"),
        "score_delta": score_delta,
        "field_accuracy_pct": field_acc["field_accuracy_pct"],
        "fields": field_acc["fields"],
    })
def _gt_has_ingredients(ground_truth_json) -> bool:
    """Return True if a stored ground-truth JSON blob has non-blank ingredients.

    OCR word-F1 is only meaningful against a reference ingredient list.
    Without one, the stored F1 is 0.0 by construction and says nothing
    about OCR quality.
    """
    try:
        gt = json.loads(ground_truth_json or "{}")
    except (TypeError, ValueError):
        return False
    if not isinstance(gt, dict):
        return False
    return bool(str(gt.get("ingredients") or "").strip())


@router.get("/report")
async def accuracy_report(request: Request):
    """
    Aggregate accuracy report across all benchmarks.
    Publish this to build trust + validate claims.

    Every benchmark that has been run (has stored LLM output) is included,
    even when its OCR F1 is 0. Filtering on ``f1_score > 0`` would silently
    hide the worst failures from a published report. ``avg_ocr_f1`` averages
    only benchmarks whose ground truth lists ingredients, since F1 is
    undefined without a reference text.
    """
    with db_conn() as conn:
        rows = conn.execute(
            """SELECT product_name, f1_score, score_delta, field_accuracy,
                      tested_at, ground_truth_json
               FROM benchmarks WHERE llm_output_json IS NOT NULL
               ORDER BY tested_at DESC"""
        ).fetchall()
    if not rows:
        return JSONResponse({
            "message": "No benchmarks run yet.",
            "action": "POST /benchmarks/submit-ground-truth to register products, then POST /benchmarks/run/{id}",
            "products_tested": 0,
        })

    f1_scores = [
        r["f1_score"] for r in rows
        if r["f1_score"] is not None and _gt_has_ingredients(r["ground_truth_json"])
    ]
    score_deltas = [abs(r["score_delta"]) for r in rows if r["score_delta"] is not None]

    field_accs = []
    for r in rows:
        try:
            fa = json.loads(r["field_accuracy"] or "{}")
        except (TypeError, ValueError):
            logger.warning("Unreadable field_accuracy for %r; skipping", r["product_name"])
            continue
        pct = fa.get("field_accuracy_pct") if isinstance(fa, dict) else None
        if isinstance(pct, (int, float)) and not isinstance(pct, bool):
            field_accs.append(pct)

    return JSONResponse({
        "products_tested": len(rows),
        "avg_ocr_f1": round(sum(f1_scores) / len(f1_scores), 3) if f1_scores else 0,
        "avg_score_delta": round(sum(score_deltas) / len(score_deltas), 2) if score_deltas else 0,
        "avg_field_accuracy": f"{round(sum(field_accs)/len(field_accs), 1)}%" if field_accs else "N/A",
        "results": [
            {"product": r["product_name"], "ocr_f1": r["f1_score"],
             "score_delta": r["score_delta"], "tested_at": r["tested_at"]}
            for r in rows
        ],
    })