File size: 5,729 Bytes
16c7630
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import base64
import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Tuple

from dotenv import load_dotenv
from openai import OpenAI
from transformers import pipeline

from labels import load_labels

# Runs at import time, before any of the os.getenv() lookups below.
# NOTE(review): presumably this picks up a local .env file — confirm against
# the python-dotenv documentation.
load_dotenv()

# Absolute directory containing this module; relative project paths
# (local model directory, example images) are resolved against it.
PROJECT_DIR = Path(__file__).resolve().parent


class ModelComparison:
    """Run one image through three classifiers and collect their predictions.

    The three back-ends are:

    * a "custom" image-classification pipeline (model chosen by
      ``_load_custom_classifier``),
    * a zero-shot CLIP pipeline scored against ``self.labels``,
    * an OpenAI vision model queried through the Responses API (skipped with
      an error payload when no API key is configured).
    """

    # MIME types used for the base64 data URL sent to OpenAI, keyed by
    # lower-cased file suffix. Unknown suffixes fall back to image/jpeg.
    _IMAGE_MIME_TYPES = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".webp": "image/webp",
        ".gif": "image/gif",
    }
    _DEFAULT_CLIP_MODEL = "openai/clip-vit-large-patch14"
    _FALLBACK_VIT_MODEL = "google/vit-base-patch16-224"

    def __init__(self) -> None:
        """Load the label list and build all three model back-ends."""
        self.labels = load_labels()
        self.openai_model = os.getenv("OPENAI_MODEL", "gpt-4.1-mini")
        self.openai_client = self._create_openai_client()
        self.custom_classifier, self.custom_model_name = self._load_custom_classifier()
        # Remember which CLIP checkpoint was actually loaded so reports stay
        # accurate even if the environment variable changes later.
        self.clip_model_name = os.getenv("CLIP_MODEL_ID", self._DEFAULT_CLIP_MODEL)
        self.clip_detector = pipeline(
            task="zero-shot-image-classification",
            model=self.clip_model_name,
        )

    @staticmethod
    def _create_openai_client() -> "OpenAI | None":
        """Return an OpenAI client, or ``None`` when OPENAI_API_KEY is unset/empty."""
        api_key = os.getenv("OPENAI_API_KEY")
        return OpenAI(api_key=api_key) if api_key else None

    @staticmethod
    def _encode_image(image_path: str) -> str:
        """Return the file's bytes as a base64-encoded ASCII string."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    @classmethod
    def _guess_image_mime(cls, image_path: str) -> str:
        """Return the MIME type implied by the file suffix (default image/jpeg)."""
        suffix = os.path.splitext(str(image_path))[1].lower()
        return cls._IMAGE_MIME_TYPES.get(suffix, "image/jpeg")

    def _load_custom_classifier(self):
        """Load the image-classification pipeline, preferring custom models.

        Candidates are tried in order: the ``HF_MODEL_ID`` environment
        variable, the local ``models/custom-vit-model`` directory (only if it
        exists), and finally the stock ViT checkpoint.

        Returns:
            A ``(pipeline, model_name)`` tuple for the first candidate that loads.

        Raises:
            Whatever ``pipeline`` raises if even the final fallback cannot be
            loaded; failures of earlier candidates are logged and skipped.
        """
        log = logging.getLogger(__name__)
        local_model_dir = PROJECT_DIR / "models" / "custom-vit-model"

        preferred = []
        hub_model_id = os.getenv("HF_MODEL_ID")
        if hub_model_id:
            preferred.append(hub_model_id)
        if local_model_dir.exists():
            preferred.append(str(local_model_dir))

        for candidate in preferred:
            try:
                return pipeline(task="image-classification", model=candidate), candidate
            except Exception:
                # Best-effort: a broken custom model must not prevent start-up,
                # but the failure should be visible in the logs.
                log.warning(
                    "Could not load image classifier %r; trying the next candidate.",
                    candidate,
                    exc_info=True,
                )

        # Last fallback should always exist in practice; let errors propagate.
        fallback = self._FALLBACK_VIT_MODEL
        return pipeline(task="image-classification", model=fallback), fallback

    @staticmethod
    def _to_top_k_dict(results: List[Dict], k: int = 3) -> Dict[str, float]:
        """Map the ``k`` highest-scoring results to ``{label: score}``.

        Args:
            results: Dicts with ``"label"`` and ``"score"`` keys.
            k: Maximum number of entries to keep.

        Returns:
            Labels mapped to scores rounded to 4 decimals, best score first.
        """
        # Rank explicitly rather than trusting the caller's ordering; the sort
        # is stable, so already-sorted input keeps its order.
        ranked = sorted(results, key=lambda item: float(item["score"]), reverse=True)
        return {item["label"]: round(float(item["score"]), 4) for item in ranked[:k]}

    @staticmethod
    def _parse_openai_output(text: str) -> Dict:
        """Parse the model's reply into a dict, tolerating Markdown code fences.

        Returns the decoded JSON object when the reply is one. Otherwise
        returns a placeholder dict with ``label="unknown"``, ``confidence=0.0``,
        the raw reply under ``reasoning`` and an explanatory ``warning``.
        """
        cleaned = text.strip()
        if cleaned.startswith("```"):
            # Drop the opening fence line (which may carry a language tag such
            # as "json") and the closing fence, keeping only the payload.
            first_newline = cleaned.find("\n")
            cleaned = cleaned[first_newline + 1:] if first_newline != -1 else cleaned[3:]
            cleaned = cleaned.rstrip()
            if cleaned.endswith("```"):
                cleaned = cleaned[:-3]
            cleaned = cleaned.strip()

        try:
            parsed = json.loads(cleaned)
        except json.JSONDecodeError:
            warning = "OpenAI response was not valid JSON."
        else:
            if isinstance(parsed, dict):
                return parsed
            warning = "OpenAI response was not a JSON object."

        return {
            "label": "unknown",
            "confidence": 0.0,
            "reasoning": text,
            "warning": warning,
        }

    def classify_with_openai(self, image_path: str) -> Dict:
        """Classify the image with the configured OpenAI vision model.

        Args:
            image_path: Path of the image file to send.

        Returns:
            The JSON object produced by the model (expected keys: ``label``,
            ``confidence``, ``reasoning``); a placeholder dict with a
            ``warning`` key when the reply cannot be parsed; or a dict with a
            single ``error`` key when no API key is configured.
        """
        if self.openai_client is None:
            return {
                "error": "OPENAI_API_KEY is missing. Add it as an environment variable or HF Space secret.",
            }

        prompt = (
            "Classify the image into one label from the following list: "
            f"{', '.join(self.labels)}. "
            "Return valid JSON with exactly these keys: label, confidence, reasoning. "
            "confidence must be a numeric value between 0 and 1."
        )

        # Advertise the real image type in the data URL instead of always
        # claiming JPEG.
        mime_type = self._guess_image_mime(image_path)
        base64_image = self._encode_image(image_path)
        response = self.openai_client.responses.create(
            model=self.openai_model,
            input=[
                {
                    "role": "user",
                    "content": [
                        {"type": "input_text", "text": prompt},
                        {
                            "type": "input_image",
                            "image_url": f"data:{mime_type};base64,{base64_image}",
                        },
                    ],
                }
            ],
        )

        return self._parse_openai_output(response.output_text)

    def classify_all(self, image_path: str) -> Dict:
        """Run all three back-ends on one image and bundle the results.

        Returns:
            A dict keyed by a human-readable back-end name. The two pipeline
            entries carry ``model`` and ``top_3``; the OpenAI entry carries
            ``model`` and ``prediction``.
        """
        custom_results = self.custom_classifier(image_path)
        clip_results = self.clip_detector(image_path, candidate_labels=self.labels)
        openai_results = self.classify_with_openai(image_path)

        return {
            "Custom Transfer Learning Model": {
                "model": self.custom_model_name,
                "top_3": self._to_top_k_dict(custom_results, k=3),
            },
            "Open-Source Zero-Shot (CLIP)": {
                "model": self.clip_model_name,
                "top_3": self._to_top_k_dict(clip_results, k=3),
            },
            "Closed-Source Vision Model (OpenAI)": {
                "model": self.openai_model,
                "prediction": openai_results,
            },
        }


def discover_example_images(example_dir: str = "example_images") -> List[List[str]]:
    """List the example image files found directly inside ``example_dir``.

    Args:
        example_dir: Directory to scan. A relative path is resolved against
            ``PROJECT_DIR``; an absolute path is used as given.

    Returns:
        One single-element list per image (``[[path], [path], ...]``), sorted
        by path. Only regular files with a ``.jpg``, ``.jpeg``, ``.png`` or
        ``.webp`` suffix (case-insensitive) are included. An empty list is
        returned when the directory is missing or is not a directory.
    """
    path = Path(example_dir)
    if not path.is_absolute():
        path = PROJECT_DIR / path

    # is_dir() also covers the "exists but is a regular file" case, where
    # iterdir() would raise NotADirectoryError.
    if not path.is_dir():
        return []

    allowed_suffixes = {".jpg", ".jpeg", ".png", ".webp"}
    images = sorted(
        p for p in path.iterdir()
        # Skip sub-directories that merely have an image-like name.
        if p.is_file() and p.suffix.lower() in allowed_suffixes
    )
    return [[str(p)] for p in images]


def classify_for_table(comparison: "ModelComparison", image_path: str) -> Tuple[str, str, str]:
    """Summarise all three model predictions as short display strings.

    Args:
        comparison: Object whose ``classify_all(image_path)`` returns the
            nested result dict built by ``ModelComparison.classify_all``.
        image_path: Path of the image to classify.

    Returns:
        ``(custom, clip, openai)`` strings. The first two are
        ``"label: score"`` pairs joined by ``"; "``. The third is
        ``"label: confidence"``, or ``"error: <message>"`` when the OpenAI
        prediction carries an ``error`` key, or ``str(prediction)`` when the
        prediction is not a dict.
    """
    result = comparison.classify_all(image_path)

    custom_top = result["Custom Transfer Learning Model"]["top_3"]
    clip_top = result["Open-Source Zero-Shot (CLIP)"]["top_3"]
    openai_pred = result["Closed-Source Vision Model (OpenAI)"]["prediction"]

    custom_str = "; ".join(f"{label}: {score}" for label, score in custom_top.items())
    clip_str = "; ".join(f"{label}: {score}" for label, score in clip_top.items())

    if not isinstance(openai_pred, dict):
        openai_str = str(openai_pred)
    elif "error" in openai_pred:
        # Surface the failure (e.g. missing API key) instead of hiding it
        # behind a misleading "unknown: n/a" entry.
        openai_str = f"error: {openai_pred['error']}"
    else:
        openai_label = str(openai_pred.get("label", "unknown"))
        openai_conf = str(openai_pred.get("confidence", "n/a"))
        openai_str = f"{openai_label}: {openai_conf}"

    return custom_str, clip_str, openai_str