Abgabe2 / model_comparison.py
nbacchi's picture
Upload 6 files
16c7630 verified
import base64
import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Tuple

from dotenv import load_dotenv
from openai import OpenAI
from transformers import pipeline

from labels import load_labels
# Populate the process environment (python-dotenv) before any of the
# os.getenv() lookups further down in this module are executed.
load_dotenv()
# Directory that contains this module; relative model and example-image
# paths below are resolved against it.
PROJECT_DIR = Path(__file__).resolve().parent
class ModelComparison:
    """Classify one image with three different models and collect the results.

    The three models are:

    * an image-classification pipeline (Hub model from ``HF_MODEL_ID``, a
      local fine-tuned model directory, or a generic ViT fallback),
    * a zero-shot CLIP pipeline scored against ``self.labels``,
    * an OpenAI vision model (only when ``OPENAI_API_KEY`` is set).
    """

    _DEFAULT_CLIP_MODEL = "openai/clip-vit-large-patch14"
    _FALLBACK_VIT_MODEL = "google/vit-base-patch16-224"
    # MIME types advertised in the data URL sent to OpenAI, keyed by the
    # lower-cased file suffix. Unknown suffixes fall back to JPEG.
    _MIME_BY_SUFFIX = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".webp": "image/webp",
        ".gif": "image/gif",
    }

    def __init__(self) -> None:
        """Load the label list and instantiate all three model back ends."""
        self.labels = load_labels()
        self.openai_model = os.getenv("OPENAI_MODEL", "gpt-4.1-mini")
        self.openai_client = self._create_openai_client()
        self.custom_classifier, self.custom_model_name = self._load_custom_classifier()
        # Resolve the CLIP model id once so the name reported by
        # classify_all() is always the model that was actually loaded.
        self.clip_model_name = os.getenv("CLIP_MODEL_ID", self._DEFAULT_CLIP_MODEL)
        self.clip_detector = pipeline(
            task="zero-shot-image-classification",
            model=self.clip_model_name,
        )

    @staticmethod
    def _create_openai_client() -> "OpenAI | None":
        """Return an OpenAI client, or ``None`` when ``OPENAI_API_KEY`` is unset/empty."""
        api_key = os.getenv("OPENAI_API_KEY")
        return OpenAI(api_key=api_key) if api_key else None

    @staticmethod
    def _encode_image(image_path: str) -> str:
        """Return the content of ``image_path`` as a base64-encoded ASCII string."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    @classmethod
    def _image_data_url(cls, image_path: str) -> str:
        """Return a ``data:`` URL for the image with a MIME type matching its suffix."""
        suffix = Path(image_path).suffix.lower()
        mime_type = cls._MIME_BY_SUFFIX.get(suffix, "image/jpeg")
        return f"data:{mime_type};base64,{cls._encode_image(image_path)}"

    def _load_custom_classifier(self):
        """Load the first usable image-classification pipeline.

        Candidates are tried in order: the ``HF_MODEL_ID`` environment
        variable, the local ``models/custom-vit-model`` directory (only if it
        exists), and finally the generic ViT fallback.

        Returns:
            A ``(pipeline, model_name)`` tuple.

        Raises:
            Whatever ``pipeline`` raises if even the final fallback attempt fails.
        """
        logger = logging.getLogger(__name__)
        local_model_dir = PROJECT_DIR / "models" / "custom-vit-model"
        candidates = [
            os.getenv("HF_MODEL_ID"),
            str(local_model_dir),
            self._FALLBACK_VIT_MODEL,
        ]
        for candidate in candidates:
            if not candidate:
                continue
            if candidate == str(local_model_dir) and not local_model_dir.exists():
                continue
            try:
                clf = pipeline(task="image-classification", model=candidate)
            except Exception as exc:
                # Best effort: keep trying the remaining candidates, but leave
                # a trace of why this one was skipped.
                logger.warning("Could not load image classifier %r: %s", candidate, exc)
                continue
            return clf, candidate
        # Last fallback should always exist in practice. Retrying it outside
        # the try block lets the real loading error reach the caller.
        fallback = self._FALLBACK_VIT_MODEL
        return pipeline(task="image-classification", model=fallback), fallback

    @staticmethod
    def _to_top_k_dict(results: List[Dict], k: int = 3) -> Dict[str, float]:
        """Map the first ``k`` results to ``{label: score rounded to 4 decimals}``.

        NOTE(review): takes the first ``k`` entries as given, so it assumes
        ``results`` is already ordered best-first - confirm for the pipelines used.
        """
        return {
            result["label"]: round(float(result["score"]), 4)
            for result in results[:k]
        }

    @staticmethod
    def _parse_prediction(raw_text: str) -> Dict:
        """Parse the model's answer as JSON, tolerating a Markdown code fence.

        Returns the decoded JSON value (expected to be an object). If the text
        is not valid JSON, a placeholder dict with label ``"unknown"`` is
        returned and the raw text is kept under ``"reasoning"``.
        """
        text = raw_text.strip()
        if text.startswith("```"):
            # Drop the opening fence line (with optional language tag) and
            # everything from the closing fence onwards.
            text = text.partition("\n")[2].rsplit("```", 1)[0]
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return {
                "label": "unknown",
                "confidence": 0.0,
                "reasoning": raw_text,
                "warning": "OpenAI response was not valid JSON.",
            }

    def classify_with_openai(self, image_path: str) -> Dict:
        """Ask the configured OpenAI model to pick one label for the image.

        Returns:
            ``{"error": ...}`` when no client is configured; otherwise the
            parsed JSON answer, or a placeholder dict if it is not valid JSON.
        """
        if self.openai_client is None:
            return {
                "error": "OPENAI_API_KEY is missing. Add it as an environment variable or HF Space secret.",
            }
        prompt = (
            "Classify the image into one label from the following list: "
            f"{', '.join(self.labels)}. "
            "Return valid JSON with exactly these keys: label, confidence, reasoning. "
            "confidence must be a numeric value between 0 and 1."
        )
        response = self.openai_client.responses.create(
            model=self.openai_model,
            input=[
                {
                    "role": "user",
                    "content": [
                        {"type": "input_text", "text": prompt},
                        {
                            "type": "input_image",
                            "image_url": self._image_data_url(image_path),
                        },
                    ],
                }
            ],
        )
        return self._parse_prediction(response.output_text)

    def classify_all(self, image_path: str) -> Dict:
        """Run all three models on ``image_path`` and return their results.

        The returned dict has one entry per model, keyed by a display name;
        each entry holds the model id plus either ``top_3`` scores or the
        OpenAI ``prediction``.
        """
        custom_results = self.custom_classifier(image_path)
        clip_results = self.clip_detector(image_path, candidate_labels=self.labels)
        openai_results = self.classify_with_openai(image_path)
        return {
            "Custom Transfer Learning Model": {
                "model": self.custom_model_name,
                "top_3": self._to_top_k_dict(custom_results, k=3),
            },
            "Open-Source Zero-Shot (CLIP)": {
                "model": self.clip_model_name,
                "top_3": self._to_top_k_dict(clip_results, k=3),
            },
            "Closed-Source Vision Model (OpenAI)": {
                "model": self.openai_model,
                "prediction": openai_results,
            },
        }
def discover_example_images(example_dir: str = "example_images") -> List[List[str]]:
    """List the example images in ``example_dir`` as single-element rows.

    Args:
        example_dir: Directory to scan. A relative path is resolved against
            ``PROJECT_DIR``.

    Returns:
        One ``[path]`` list per regular file with a ``.jpg``, ``.jpeg``,
        ``.png`` or ``.webp`` suffix (case-insensitive), sorted by path.
        An empty list if ``example_dir`` is missing or is not a directory.
    """
    directory = Path(example_dir)
    if not directory.is_absolute():
        directory = PROJECT_DIR / directory
    # is_dir() rather than exists(): a regular file at this path would make
    # iterdir() raise NotADirectoryError.
    if not directory.is_dir():
        return []
    image_suffixes = {".jpg", ".jpeg", ".png", ".webp"}
    images = sorted(
        entry
        for entry in directory.iterdir()
        # Skip sub-directories that merely have an image-like name.
        if entry.is_file() and entry.suffix.lower() in image_suffixes
    )
    return [[str(entry)] for entry in images]
def classify_for_table(comparison: "ModelComparison", image_path: str) -> Tuple[str, str, str]:
    """Classify ``image_path`` and format each model's result as one table cell.

    Args:
        comparison: Object whose ``classify_all(image_path)`` returns the
            three-model result dict.
        image_path: Path of the image to classify.

    Returns:
        ``(custom, clip, openai)`` strings. The first two are
        ``"label: score"`` pairs joined by ``"; "``. The third is
        ``"label: confidence"``, or ``"error: <message>"`` when the OpenAI
        step reported an error instead of a prediction.
    """
    result = comparison.classify_all(image_path)
    custom_top = result["Custom Transfer Learning Model"]["top_3"]
    clip_top = result["Open-Source Zero-Shot (CLIP)"]["top_3"]
    openai_pred = result["Closed-Source Vision Model (OpenAI)"]["prediction"]

    custom_str = "; ".join(f"{label}: {score}" for label, score in custom_top.items())
    clip_str = "; ".join(f"{label}: {score}" for label, score in clip_top.items())

    if not isinstance(openai_pred, dict):
        openai_str = str(openai_pred)
    elif "error" in openai_pred and "label" not in openai_pred:
        # Surface the failure reason (e.g. missing API key) instead of a
        # misleading "unknown: n/a" cell.
        openai_str = f"error: {openai_pred['error']}"
    else:
        openai_label = str(openai_pred.get("label", "unknown"))
        openai_conf = str(openai_pred.get("confidence", "n/a"))
        openai_str = f"{openai_label}: {openai_conf}"
    return custom_str, clip_str, openai_str