e
File size: 9,320 Bytes
57e072f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
"""
app/routes/benchmarks.py
Accuracy benchmarking system.

Doc says: "Your accuracy is unknowable. Without ground-truth labels
and systematic accuracy testing, you have no idea if your health scores
are correct. This is dangerous for a health product."

This module fixes that:
- Store ground-truth nutrition data for test products
- Run scanner against them
- Measure F1, field accuracy, score delta
- Publish results (builds trust + validates claims)
"""
import json
import logging
import asyncio
from fastapi import APIRouter, Request, HTTPException, Form, File, UploadFile
from fastapi.responses import JSONResponse
from app.models.db import db_conn
from app.services.image import validate_image, assess_image_quality, deblur_and_enhance, ocr_quality_score
from app.services.ocr import run_ocr, detect_label_presence
from app.services.llm import analyse_label, call_llm

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router for the benchmark endpoints below; every path registered on it is
# mounted under the "/benchmarks" prefix (e.g. POST /benchmarks/run/{id}).
router = APIRouter(prefix="/benchmarks", tags=["benchmarks"])


def _compute_field_accuracy(llm_output: dict, ground_truth: dict) -> dict:
    """
    Compare LLM-extracted nutrient values against hand-verified ground truth.
    Returns per-field accuracy within a tolerance band.
    """
    fields    = ["calories", "protein", "carbs", "fat", "sodium", "fiber", "sugar"]
    results   = {}
    exact_ct  = 0

    # Flatten LLM nutrient_breakdown into a dict
    llm_nutr = {}
    for n in llm_output.get("nutrient_breakdown", []):
        key = n.get("name", "").lower()
        val = n.get("value", 0)
        if isinstance(val, (int, float)):
            llm_nutr[key] = val

    gt_nutr = ground_truth.get("nutrients", {})

    for field in fields:
        # Try to find field in LLM output (flexible naming)
        llm_val = None
        for k, v in llm_nutr.items():
            if field in k or k in field:
                llm_val = v
                break

        gt_val = gt_nutr.get(field)
        if gt_val is None or llm_val is None:
            results[field] = {"status": "missing", "llm": llm_val, "truth": gt_val}
            continue

        # Tolerance: within 15% or 2 units (whichever is larger)
        tolerance = max(abs(gt_val) * 0.15, 2)
        correct   = abs(llm_val - gt_val) <= tolerance
        if correct:
            exact_ct += 1
        results[field] = {
            "status" : "correct" if correct else "wrong",
            "llm"    : llm_val,
            "truth"  : gt_val,
            "delta"  : round(llm_val - gt_val, 2),
            "pct_err": round(abs(llm_val - gt_val) / max(gt_val, 1) * 100, 1),
        }

    # Score accuracy
    gt_score  = ground_truth.get("score")
    llm_score = llm_output.get("score")
    if gt_score is not None and llm_score is not None:
        score_delta = abs(llm_score - gt_score)
        results["score"] = {
            "status" : "correct" if score_delta <= 1 else "wrong",
            "llm"    : llm_score,
            "truth"  : gt_score,
            "delta"  : llm_score - gt_score,
        }

    accuracy_pct = round(exact_ct / len(fields) * 100, 1)
    return {"fields": results, "field_accuracy_pct": accuracy_pct}


def _word_f1(pred: str, truth: str) -> float:
    if not truth:
        return 0.0
    pw = set(pred.lower().split())
    tw = set(truth.lower().split())
    tp = len(pw & tw)
    pr = tp / len(pw) if pw else 0
    rc = tp / len(tw) if tw else 0
    return round(2 * pr * rc / (pr + rc), 3) if (pr + rc) else 0.0


@router.post("/submit-ground-truth")
async def submit_ground_truth(
    request     : Request,
    product_name: str  = Form(...),
    admin_token : str  = Form(...),
    nutrients   : str  = Form(...),  # JSON: {"calories":250,"protein":8,...}
    score       : int  = Form(...),  # hand-assigned Eatlytic score
    ingredients : str  = Form(""),
    barcode     : str  = Form(""),
):
    """
    Admin: register a product's hand-verified nutrition data as ground truth.

    Run this for 100+ products to get meaningful accuracy benchmarks.

    Form fields:
        product_name: name stored with the benchmark row.
        admin_token:  must match the ADMIN_TOKEN environment variable.
        nutrients:    JSON object mapping nutrient name -> amount,
                      e.g. {"calories": 250, "protein": 8}.
        score:        hand-assigned Eatlytic score for the product.
        ingredients:  reference ingredient text, stored in the ground truth.
        barcode:      accepted but not currently persisted.

    Returns:
        JSON with ``registered``, ``product``, ``benchmark_id`` (the new row
        id, or None if the DB driver does not expose one) and ``message``.

    Raises:
        HTTPException 403: ADMIN_TOKEN is unset or the token does not match.
        HTTPException 400: ``nutrients`` is not a JSON object.
    """
    import os
    import secrets

    expected_token = os.environ.get("ADMIN_TOKEN", "")
    if not expected_token:
        # Fail closed: a hard-coded fallback token ("changeme") would let
        # anyone write ground truth on a deployment that never set one.
        logger.warning("ADMIN_TOKEN is not set; rejecting admin request")
        raise HTTPException(status_code=403, detail="Invalid admin token")
    # Constant-time comparison so response timing does not leak the token;
    # bytes so non-ASCII input cannot raise TypeError.
    if not secrets.compare_digest(admin_token.encode(), expected_token.encode()):
        raise HTTPException(status_code=403, detail="Invalid admin token")

    try:
        gt = json.loads(nutrients)
    except json.JSONDecodeError as err:
        raise HTTPException(status_code=400, detail="nutrients must be valid JSON") from err
    # The benchmark runner reads nutrients as a name -> value mapping, so a
    # JSON list/number would only fail later, at run time.
    if not isinstance(gt, dict):
        raise HTTPException(status_code=400, detail="nutrients must be a JSON object")

    # TODO(review): `barcode` is accepted but not stored; the INSERT below
    # only covers product_name and ground_truth_json.
    ground_truth = {"nutrients": gt, "score": score, "ingredients": ingredients}

    with db_conn() as conn:
        cur = conn.execute(
            """INSERT INTO benchmarks(product_name, ground_truth_json)
               VALUES(?,?)""",
            (product_name, json.dumps(ground_truth))
        )
        # Callers need this id for POST /benchmarks/run/{id}.
        # NOTE(review): assumes an sqlite3-style cursor; None otherwise.
        benchmark_id = getattr(cur, "lastrowid", None)
    return JSONResponse({"registered": True, "product": product_name,
                         "benchmark_id": benchmark_id,
                         "message": "Ground truth saved. Run /benchmarks/run to test."})


@router.post("/run/{benchmark_id}")
async def run_benchmark(
    request      : Request,
    benchmark_id : int,
    image        : UploadFile = File(...),
    admin_token  : str = Form(...),
):
    """
    Run the scanner against a benchmark product and store accuracy metrics.

    Pipeline: validate the uploaded image, assess its quality, optionally
    deblur it (the enhanced image is kept only if its OCR quality score is at
    least 85% of the original's), OCR it, analyse the text with the LLM, and
    compare the result with the stored ground truth. Metrics are written back
    to the ``benchmarks`` row and returned.

    Raises:
        HTTPException 403: ADMIN_TOKEN is unset or the token does not match.
        HTTPException 404: no benchmark row has ``benchmark_id``.
    """
    import os
    import secrets

    expected_token = os.environ.get("ADMIN_TOKEN", "")
    if not expected_token:
        # Fail closed rather than accepting a hard-coded default token.
        logger.warning("ADMIN_TOKEN is not set; rejecting admin request")
        raise HTTPException(status_code=403, detail="Invalid admin token")
    # Constant-time comparison so response timing does not leak the token.
    if not secrets.compare_digest(admin_token.encode(), expected_token.encode()):
        raise HTTPException(status_code=403, detail="Invalid admin token")

    with db_conn() as conn:
        bm_row = conn.execute(
            "SELECT * FROM benchmarks WHERE id=?", (benchmark_id,)
        ).fetchone()
    if not bm_row:
        raise HTTPException(status_code=404, detail="Benchmark not found")

    ground_truth = json.loads(bm_row["ground_truth_json"])
    content      = validate_image(await image.read())
    quality      = assess_image_quality(content)

    # NOTE(review): image processing and OCR are synchronous calls inside an
    # async handler, so they presumably block the event loop while running.
    working, ocr_result = content, None
    if quality["is_blurry"]:
        try:
            enhanced, _ = deblur_and_enhance(content, quality["blur_severity"])
            enhanced_ocr = run_ocr(enhanced, "en")
            original_ocr = run_ocr(content, "en")
            keep_enhanced = (
                ocr_quality_score(enhanced_ocr)
                >= ocr_quality_score(original_ocr) * 0.85
            )
        except Exception:
            # Deblurring is best-effort: fall back to the original image,
            # but leave a trace instead of failing silently.
            logger.warning(
                "Deblur comparison failed for benchmark %s; using original image",
                benchmark_id, exc_info=True,
            )
        else:
            # Reuse the OCR result already computed for the chosen image
            # instead of running OCR a third time.
            if keep_enhanced:
                working, ocr_result = enhanced, enhanced_ocr
            else:
                ocr_result = original_ocr
    if ocr_result is None:
        ocr_result = run_ocr(working, "en")

    extracted_text = ocr_result["text"]
    ocr_f1         = _word_f1(extracted_text,
                               ground_truth.get("ingredients", ""))

    blur_info = {"detected": quality["is_blurry"], "severity": quality["blur_severity"],
                 "score": quality["blur_score"], "deblurred": working != content}

    llm_output = await analyse_label(
        extracted_text, "General Adult", "adult", "general",
        "en", "", blur_info, "high"
    )

    field_acc = _compute_field_accuracy(llm_output, ground_truth)

    # Report None rather than a fabricated delta when either side lacks a
    # numeric score: defaulting the missing side to 0 would present the whole
    # ground-truth score as error, and a null LLM score would raise TypeError.
    predicted_score = llm_output.get("score")
    truth_score     = ground_truth.get("score")
    if isinstance(predicted_score, (int, float)) and isinstance(truth_score, (int, float)):
        score_delta = predicted_score - truth_score
    else:
        score_delta = None

    with db_conn() as conn:
        conn.execute(
            """UPDATE benchmarks
               SET ocr_text=?, llm_output_json=?, f1_score=?,
                   score_delta=?, field_accuracy=?, tested_at=datetime('now'),
                   model_used='llama-3.3-70b'
               WHERE id=?""",
            (extracted_text,
             json.dumps(llm_output),
             ocr_f1,
             score_delta,
             json.dumps(field_acc),
             benchmark_id)
        )

    return JSONResponse({
        "benchmark_id"      : benchmark_id,
        "product_name"      : bm_row["product_name"],
        "ocr_f1"            : ocr_f1,
        "score_predicted"   : predicted_score,
        "score_truth"       : truth_score,
        "score_delta"       : score_delta,
        "field_accuracy_pct": field_acc["field_accuracy_pct"],
        "fields"            : field_acc["fields"],
    })


@router.get("/report")
async def accuracy_report(request: Request):
    """
    Aggregate accuracy report across all benchmarks that have been run.

    Publish this to build trust + validate claims.

    Returns JSON with ``products_tested``, ``avg_ocr_f1`` (mean over runs
    with a non-zero F1), ``avg_score_delta`` (mean absolute score delta over
    runs that recorded one), ``avg_field_accuracy`` (an "NN.N%" string, or
    "N/A") and one ``results`` entry per run, newest first. When nothing has
    been run yet, returns a hint on how to register and run benchmarks.
    """
    with db_conn() as conn:
        # Select every benchmark that has been run (run_benchmark stamps
        # tested_at). Filtering on f1_score > 0 instead would silently drop
        # every run whose OCR F1 is 0 — including all products registered
        # without reference ingredients — and bias the published numbers.
        # NOTE(review): assumes tested_at stays NULL until a run; confirm
        # against the schema.
        rows = conn.execute(
            """SELECT product_name, f1_score, score_delta, field_accuracy, tested_at
               FROM benchmarks WHERE tested_at IS NOT NULL
               ORDER BY tested_at DESC"""
        ).fetchall()

    if not rows:
        return JSONResponse({
            "message" : "No benchmarks run yet.",
            "action"  : "POST /benchmarks/submit-ground-truth to register products, then POST /benchmarks/run/{id}",
            "products_tested": 0,
        })

    # A zero F1 can simply mean no reference ingredients were registered, so
    # zeros (and NULLs) are still excluded from the OCR average only.
    f1_scores    = [r["f1_score"] for r in rows if r["f1_score"]]
    score_deltas = [abs(r["score_delta"]) for r in rows if r["score_delta"] is not None]

    field_accs = []
    for r in rows:
        try:
            fa = json.loads(r["field_accuracy"] or "{}")
        except (TypeError, ValueError):
            logger.warning("Unparseable field_accuracy for benchmark %r", r["product_name"])
            continue
        pct = fa.get("field_accuracy_pct") if isinstance(fa, dict) else None
        # Ignore non-numeric values so one bad row cannot break the sum.
        if isinstance(pct, (int, float)) and not isinstance(pct, bool):
            field_accs.append(pct)

    return JSONResponse({
        "products_tested"   : len(rows),
        "avg_ocr_f1"        : round(sum(f1_scores) / len(f1_scores), 3) if f1_scores else 0,
        "avg_score_delta"   : round(sum(score_deltas) / len(score_deltas), 2) if score_deltas else 0,
        "avg_field_accuracy": f"{round(sum(field_accs)/len(field_accs), 1)}%" if field_accs else "N/A",
        "results"           : [
            {"product": r["product_name"], "ocr_f1": r["f1_score"],
             "score_delta": r["score_delta"], "tested_at": r["tested_at"]}
            for r in rows
        ],
    })