File size: 3,602 Bytes
cbc82e5
 
 
 
 
 
 
 
 
 
 
 
 
 
6cf9641
 
cbc82e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cf9641
cbc82e5
 
 
 
 
 
 
 
 
6cf9641
cbc82e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cf9641
 
 
 
cbc82e5
 
 
6cf9641
cbc82e5
 
 
6cf9641
cbc82e5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""Generate bundled example assessments with the configured model endpoint."""

from __future__ import annotations

import base64
import json
import mimetypes
import sys
from datetime import date
from pathlib import Path

# Repository root: the parent of the directory that holds this script. It is
# pushed to the front of sys.path so the `app` imports that follow resolve
# when the script is executed directly rather than as part of a package.
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT))

from app.config import EXAMPLE_CACHE_PATH, model_config  # noqa: E402
from app.model_endpoint import call_model  # noqa: E402

# Sample text messages for the bundled examples. Each key is the example id
# stored in the generated cache; each value is passed to the model as `text`.
TEXT_EXAMPLES = {
    "text-courier": (
        "PAKISTAN POST: Your parcel address is incomplete. Pay Rs. 85 today at "
        "http://pakpost-delivery.xyz or the parcel will be destroyed."
    ),
    "text-fbr": (
        "FBR REFUND: You are eligible for Rs 42,500. Submit your CNIC and bank "
        "card details at the link today to receive payment."
    ),
    "text-bank": (
        "HBL Security: Your account will be suspended. Share the OTP sent to "
        "your phone with our support team immediately."
    ),
}

# Sample images for the bundled examples, keyed by example id. `main` encodes
# each file with `image_data_url` and passes the result to the model as `image`.
# NOTE(review): the traffic image file is spelled "trafic" while its id is
# "image-traffic"; presumably the path matches the file on disk -- verify
# before renaming either one.
IMAGE_EXAMPLES = {
    "image-courier": ROOT / "static" / "example-courier.jpeg",
    "image-mobile": ROOT / "static" / "example-mobile.png",
    "image-traffic": ROOT / "static" / "example-trafic.png",
}


def image_data_url(path: Path) -> str:
    """Encode the file at *path* as a base64 ``data:`` URL.

    The MIME type is guessed from the file name alone; when no type can be
    guessed, ``application/octet-stream`` is used instead.
    """
    guessed_type, _content_encoding = mimetypes.guess_type(path.name)
    media_type = guessed_type if guessed_type else "application/octet-stream"
    raw = path.read_bytes()
    payload = base64.b64encode(raw).decode("ascii")
    return f"data:{media_type};base64,{payload}"


def quality_issue(example_id: str, assessment: dict[str, object]) -> str:
    """Return why *assessment* should be regenerated, or ``""`` if it passes.

    The checks are case-insensitive and are reported in priority order:

    1. any safe next step mentions "social media";
    2. the simple explanation contains "in the future" or "in the past";
    3. for ``image-traffic`` only, a next step mentions FBR or NADRA.

    Only the first failing check's reason is returned.
    """
    explanation_lc = str(assessment["simple_explanation"]).lower()
    steps_lc = " ".join(
        str(step) for step in assessment["safe_next_steps"]  # type: ignore[union-attr]
    ).lower()

    mentions_date_comparison = (
        "in the future" in explanation_lc or "in the past" in explanation_lc
    )
    traffic_names_wrong_authority = example_id == "image-traffic" and (
        "fbr" in steps_lc or "nadra" in steps_lc
    )

    # (failed?, reason) pairs; the first failing entry wins.
    checks = (
        ("social media" in steps_lc, "safe next steps recommend social media"),
        (mentions_date_comparison, "explanation makes an unsupported date comparison"),
        (traffic_names_wrong_authority, "traffic fine advice names an unrelated authority"),
    )
    return next((reason for failed, reason in checks if failed), "")


def generate_assessment(
    example_id: str,
    *,
    text: str = "",
    image: str = "",
    max_attempts: int = 3,
) -> dict[str, object]:
    """Call the model until its assessment passes ``quality_issue``.

    Args:
        example_id: Example key; used in progress output and passed to
            ``quality_issue``, which applies some id-specific checks.
        text: Message text forwarded to ``call_model``.
        image: Image data URL forwarded to ``call_model``.
        max_attempts: Maximum number of model calls before giving up. The
            default of 3 matches the previously hard-coded limit.

    Returns:
        The first assessment for which ``quality_issue`` returns ``""``.

    Raises:
        ValueError: If ``max_attempts`` is less than 1. Without this check, the
            loop would never run and the ``RuntimeError`` would carry an empty
            reason.
        RuntimeError: If every attempt fails the quality checks. The message
            includes the reason from the last attempt.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be at least 1, got {max_attempts}")

    last_issue = ""
    for attempt in range(1, max_attempts + 1):
        assessment = call_model(text, image)
        last_issue = quality_issue(example_id, assessment)
        if not last_issue:
            print(f"{example_id}: accepted on attempt {attempt}")
            return assessment
        print(f"{example_id}: retrying after attempt {attempt}: {last_issue}")
    raise RuntimeError(f"{example_id} failed cache quality checks: {last_issue}")


def main() -> None:
    """Regenerate every bundled example and write them to ``EXAMPLE_CACHE_PATH``.

    Text examples are generated first, then image examples. The output is a
    JSON document recording the model configuration, today's date and the
    assessments keyed by example id.

    Raises:
        FileNotFoundError: If any file listed in ``IMAGE_EXAMPLES`` is missing.
            This is checked before any model call is made.
        RuntimeError: Propagated from ``generate_assessment`` when an example
            never passes the quality checks.
    """
    # Validate image paths up front. Otherwise a missing file would only
    # surface after all text examples had already been generated, wasting
    # those model calls.
    missing = [str(path) for path in IMAGE_EXAMPLES.values() if not path.is_file()]
    if missing:
        raise FileNotFoundError(f"Missing example image(s): {', '.join(missing)}")

    config = model_config()
    examples = {
        example_id: generate_assessment(example_id, text=text)
        for example_id, text in TEXT_EXAMPLES.items()
    }
    examples.update(
        {
            example_id: generate_assessment(
                example_id,
                image=image_data_url(path),
            )
            for example_id, path in IMAGE_EXAMPLES.items()
        }
    )

    document = {
        "model_repo": config.repo_id,
        "model_name": config.filename,
        "endpoint": config.source,
        # NOTE(review): this label is hard-coded and does not come from
        # `config`; keep it in sync if the runtime ever changes.
        "endpoint_type": "In-process llama.cpp runtime",
        "generated_at": date.today().isoformat(),
        "examples": examples,
    }
    # Serialize fully before writing so a JSON error cannot leave a
    # truncated cache file behind.
    payload = json.dumps(document, indent=2, ensure_ascii=True) + "\n"
    EXAMPLE_CACHE_PATH.write_text(payload, encoding="utf-8")
    print(f"Generated {len(examples)} assessments in {EXAMPLE_CACHE_PATH}")


if __name__ == "__main__":
    main()