"""Run one image through three classifiers and format the results.

The three back ends are a custom (or fallback) ViT image-classification
pipeline, a CLIP zero-shot pipeline, and an OpenAI vision model.
"""

import base64
import json
import logging
import mimetypes
import os
from pathlib import Path
from typing import Any, Dict, List, Tuple

from dotenv import load_dotenv
from openai import OpenAI
from transformers import pipeline

from labels import load_labels

load_dotenv()

logger = logging.getLogger(__name__)

PROJECT_DIR = Path(__file__).resolve().parent

_DEFAULT_CLIP_MODEL_ID = "openai/clip-vit-large-patch14"
_FALLBACK_VIT_MODEL_ID = "google/vit-base-patch16-224"
_EXAMPLE_IMAGE_SUFFIXES = frozenset({".jpg", ".jpeg", ".png", ".webp"})


class ModelComparison:
    """Hold the three classifiers and run them side by side on one image."""

    def __init__(self) -> None:
        """Load labels, the OpenAI client, and both Hugging Face pipelines."""
        self.labels = load_labels()
        self.openai_model = os.getenv("OPENAI_MODEL", "gpt-4.1-mini")
        self.openai_client = self._create_openai_client()
        self.custom_classifier, self.custom_model_name = (
            self._load_custom_classifier()
        )
        # Resolve the CLIP model id once so the name reported by
        # classify_all() is always the model that was actually loaded.
        self.clip_model_name = os.getenv("CLIP_MODEL_ID", _DEFAULT_CLIP_MODEL_ID)
        self.clip_detector = pipeline(
            task="zero-shot-image-classification",
            model=self.clip_model_name,
        )

    @staticmethod
    def _create_openai_client() -> OpenAI | None:
        """Return an OpenAI client, or None when OPENAI_API_KEY is unset/empty."""
        api_key = os.getenv("OPENAI_API_KEY")
        return OpenAI(api_key=api_key) if api_key else None

    @staticmethod
    def _encode_image(image_path: str) -> str:
        """Return the file's bytes as a base64-encoded ASCII string."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    @staticmethod
    def _guess_image_mime_type(image_path: str) -> str:
        """Guess the image MIME type from the file name.

        Falls back to ``image/jpeg`` when the extension is unknown or does
        not map to an image type.
        """
        guessed, _encoding = mimetypes.guess_type(image_path)
        if guessed and guessed.startswith("image/"):
            return guessed
        return "image/jpeg"

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding Markdown code fence (``` or ```json), if any."""
        stripped = text.strip()
        if not stripped.startswith("```"):
            return stripped
        first_newline = stripped.find("\n")
        if first_newline == -1:
            # Single-line text that merely starts with backticks: leave it.
            return stripped
        body = stripped[first_newline + 1:].rstrip()
        if body.endswith("```"):
            body = body[:-3]
        return body.strip()

    def _load_custom_classifier(self) -> Tuple[Any, str]:
        """Load the first usable image-classification pipeline.

        Candidates are tried in order: the ``HF_MODEL_ID`` environment
        variable, the local ``models/custom-vit-model`` directory, then the
        public ViT fallback.

        Returns:
            A ``(pipeline, model_name)`` tuple.
        """
        local_model_dir = PROJECT_DIR / "models" / "custom-vit-model"
        candidates = [
            os.getenv("HF_MODEL_ID"),
            str(local_model_dir),
            _FALLBACK_VIT_MODEL_ID,
        ]

        for candidate in candidates:
            if not candidate:
                continue
            if candidate == str(local_model_dir) and not local_model_dir.exists():
                continue
            try:
                clf = pipeline(task="image-classification", model=candidate)
            except Exception:
                # Best effort: record why this candidate failed, then move on.
                logger.warning(
                    "Could not load image classifier %r; trying next candidate.",
                    candidate,
                    exc_info=True,
                )
                continue
            return clf, candidate

        # Every candidate failed. Retry the fallback outside the try block so
        # that the underlying error propagates to the caller.
        return (
            pipeline(task="image-classification", model=_FALLBACK_VIT_MODEL_ID),
            _FALLBACK_VIT_MODEL_ID,
        )

    @staticmethod
    def _to_top_k_dict(results: List[Dict], k: int = 3) -> Dict[str, float]:
        """Map the first ``k`` results to ``{label: score rounded to 4 places}``.

        The first ``k`` entries are taken as-is; no sorting is done here.
        """
        return {
            result["label"]: round(float(result["score"]), 4)
            for result in results[:k]
        }

    def classify_with_openai(self, image_path: str) -> Dict:
        """Ask the OpenAI model to classify the image into one known label.

        Returns:
            The parsed JSON reply; a dict with an ``error`` key when no API
            key is configured; or a placeholder dict with a ``warning`` key
            when the reply is not valid JSON.

        Exceptions raised by the OpenAI client call are not caught here.
        """
        if self.openai_client is None:
            return {
                "error": "OPENAI_API_KEY is missing. Add it as an environment variable or HF Space secret.",
            }

        prompt = (
            "Classify the image into one label from the following list: "
            f"{', '.join(self.labels)}. "
            "Return valid JSON with exactly these keys: label, confidence, reasoning. "
            "confidence must be a numeric value between 0 and 1."
        )

        base64_image = self._encode_image(image_path)
        # Declare the real image type; PNG/WebP inputs are not JPEG.
        mime_type = self._guess_image_mime_type(image_path)
        response = self.openai_client.responses.create(
            model=self.openai_model,
            input=[
                {
                    "role": "user",
                    "content": [
                        {"type": "input_text", "text": prompt},
                        {
                            "type": "input_image",
                            "image_url": f"data:{mime_type};base64,{base64_image}",
                        },
                    ],
                }
            ],
        )

        raw_text = response.output_text
        try:
            # The reply may wrap the JSON in a Markdown fence; strip it first.
            parsed = json.loads(self._strip_code_fence(raw_text))
        except json.JSONDecodeError:
            parsed = {
                "label": "unknown",
                "confidence": 0.0,
                "reasoning": raw_text,
                "warning": "OpenAI response was not valid JSON.",
            }
        return parsed

    def classify_all(self, image_path: str) -> Dict:
        """Run all three classifiers and return their results in one dict."""
        custom_results = self.custom_classifier(image_path)
        clip_results = self.clip_detector(image_path, candidate_labels=self.labels)
        openai_results = self.classify_with_openai(image_path)

        return {
            "Custom Transfer Learning Model": {
                "model": self.custom_model_name,
                "top_3": self._to_top_k_dict(custom_results, k=3),
            },
            "Open-Source Zero-Shot (CLIP)": {
                "model": self.clip_model_name,
                "top_3": self._to_top_k_dict(clip_results, k=3),
            },
            "Closed-Source Vision Model (OpenAI)": {
                "model": self.openai_model,
                "prediction": openai_results,
            },
        }


def discover_example_images(example_dir: str = "example_images") -> List[List[str]]:
    """List example image files as single-element rows, sorted by path.

    A relative ``example_dir`` is resolved against ``PROJECT_DIR``. Returns
    an empty list when the directory does not exist.
    """
    path = Path(example_dir)
    if not path.is_absolute():
        path = PROJECT_DIR / path
    if not path.exists():
        return []

    images = sorted(
        p
        for p in path.iterdir()
        # Require a regular file so a directory named like "x.png" is skipped.
        if p.is_file() and p.suffix.lower() in _EXAMPLE_IMAGE_SUFFIXES
    )
    return [[str(p)] for p in images]


def classify_for_table(
    comparison: ModelComparison, image_path: str
) -> Tuple[str, str, str]:
    """Classify one image and return display strings for a three-column table.

    Returns:
        ``(custom_str, clip_str, openai_str)``. The first two are
        ``"label: score"`` pairs joined by ``"; "``. The third is
        ``"label: confidence"``, or the error message when the OpenAI
        result carries an ``error`` key.
    """
    result = comparison.classify_all(image_path)

    custom_top = result["Custom Transfer Learning Model"]["top_3"]
    clip_top = result["Open-Source Zero-Shot (CLIP)"]["top_3"]
    openai_pred = result["Closed-Source Vision Model (OpenAI)"]["prediction"]

    custom_str = "; ".join(f"{k}: {v}" for k, v in custom_top.items())
    clip_str = "; ".join(f"{k}: {v}" for k, v in clip_top.items())

    if isinstance(openai_pred, dict):
        if "error" in openai_pred:
            # Show the failure reason instead of a misleading "unknown: n/a".
            openai_str = f"error: {openai_pred['error']}"
        else:
            openai_label = str(openai_pred.get("label", "unknown"))
            openai_conf = str(openai_pred.get("confidence", "n/a"))
            openai_str = f"{openai_label}: {openai_conf}"
    else:
        openai_str = str(openai_pred)

    return custom_str, clip_str, openai_str