import base64
import json
import logging
import mimetypes
import os
from pathlib import Path
from typing import Dict, List, Tuple

from dotenv import load_dotenv
from openai import OpenAI
from transformers import pipeline

from labels import load_labels
|
|
# Populate os.environ from a .env file (python-dotenv) before any of the
# os.getenv() lookups performed when the classifiers are constructed.
load_dotenv()


# Directory containing this module; relative resource paths (local model
# directory, example images) are resolved against it rather than the CWD.
PROJECT_DIR = Path(__file__).resolve().parent
|
|
|
|
_logger = logging.getLogger(__name__)


class ModelComparison:
    """Run one image through three classifiers and collect their predictions.

    The back ends are:

    * a Hugging Face ``image-classification`` pipeline (see
      ``_load_custom_classifier`` for how the model is chosen),
    * a ``zero-shot-image-classification`` pipeline scored against
      ``self.labels``,
    * an OpenAI vision model, used only when ``OPENAI_API_KEY`` is set.
    """

    # Extensions that may be missing from the platform's ``mimetypes`` table.
    _EXTRA_MIME_TYPES = {".webp": "image/webp"}
    # MIME type used when the extension is unknown or not an image type.
    _DEFAULT_MIME_TYPE = "image/jpeg"

    def __init__(self) -> None:
        """Load the label list and construct all three classifiers."""
        self.labels = load_labels()
        self.openai_model = os.getenv("OPENAI_MODEL", "gpt-4.1-mini")
        self.openai_client = self._create_openai_client()
        self.custom_classifier, self.custom_model_name = self._load_custom_classifier()
        # Resolve the CLIP model id once so the name reported by classify_all()
        # always matches the model that was actually loaded.
        self.clip_model_name = os.getenv("CLIP_MODEL_ID", "openai/clip-vit-large-patch14")
        self.clip_detector = pipeline(
            task="zero-shot-image-classification",
            model=self.clip_model_name,
        )

    @staticmethod
    def _create_openai_client() -> "OpenAI | None":
        """Return an OpenAI client, or ``None`` when ``OPENAI_API_KEY`` is unset/empty."""
        # The annotation is a string so it is not evaluated at definition time
        # (the ``X | None`` form fails on interpreters older than 3.10).
        api_key = os.getenv("OPENAI_API_KEY")
        return OpenAI(api_key=api_key) if api_key else None

    @staticmethod
    def _encode_image(image_path: str) -> str:
        """Return the file's raw bytes as a base64-encoded ASCII string."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    @classmethod
    def _guess_image_mime(cls, image_path: str) -> str:
        """Return an ``image/*`` MIME type for *image_path* based on its extension.

        Falls back to ``image/jpeg`` when the extension is unknown or maps to
        a non-image type.
        """
        suffix = Path(image_path).suffix.lower()
        mime = cls._EXTRA_MIME_TYPES.get(suffix) or mimetypes.guess_type(image_path)[0]
        if mime and mime.startswith("image/"):
            return mime
        return cls._DEFAULT_MIME_TYPE

    def _load_custom_classifier(self):
        """Load the image-classification pipeline, preferring custom models.

        Candidates are tried in order: the ``HF_MODEL_ID`` environment
        variable, then ``models/custom-vit-model`` under the project directory
        (only if it exists), and finally ``google/vit-base-patch16-224``.

        Returns:
            A ``(pipeline, model_name)`` tuple for the first model that loads.

        Raises:
            Whatever ``pipeline`` raises if the final fallback model cannot be
            loaded either.
        """
        fallback = "google/vit-base-patch16-224"
        local_model_dir = PROJECT_DIR / "models" / "custom-vit-model"

        preferred: List[str] = []
        for candidate in (
            os.getenv("HF_MODEL_ID"),
            str(local_model_dir) if local_model_dir.exists() else None,
        ):
            # Skip unset entries and duplicates; the fallback is handled below
            # so that it is attempted exactly once.
            if candidate and candidate != fallback and candidate not in preferred:
                preferred.append(candidate)

        for candidate in preferred:
            try:
                return pipeline(task="image-classification", model=candidate), candidate
            except Exception:
                # Best-effort: a broken custom model must not prevent start-up,
                # but the failure should be visible in the logs.
                _logger.warning(
                    "Could not load image classifier %r; trying the next candidate.",
                    candidate,
                    exc_info=True,
                )

        # Last resort: let any error propagate, nothing else is left to try.
        return pipeline(task="image-classification", model=fallback), fallback

    @staticmethod
    def _to_top_k_dict(results: List[Dict], k: int = 3) -> Dict[str, float]:
        """Map the first *k* results to ``{label: score rounded to 4 places}``.

        *results* is a list of dicts with ``"label"`` and ``"score"`` keys;
        a non-positive *k* yields an empty dict.
        """
        return {
            result["label"]: round(float(result["score"]), 4)
            for result in results[: max(k, 0)]
        }

    @staticmethod
    def _parse_openai_json(raw_text: str) -> Dict:
        """Parse the model's reply into a dict, tolerating Markdown code fences.

        Returns the decoded JSON object, or an ``"unknown"`` placeholder dict
        (with the raw reply under ``"reasoning"`` and a ``"warning"``) when the
        reply is not valid JSON or is not a JSON object.
        """
        text = raw_text.strip()
        if text.startswith("```"):
            # Remove the fence back-ticks, then the optional language tag line
            # (e.g. "json") that follows the opening fence.
            text = text.strip("`")
            first_line, newline, rest = text.partition("\n")
            if newline and not first_line.strip().startswith(("{", "[")):
                text = rest

        warning = "OpenAI response was not valid JSON."
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            parsed = None
        else:
            if isinstance(parsed, dict):
                return parsed
            warning = "OpenAI response was not a JSON object."

        return {
            "label": "unknown",
            "confidence": 0.0,
            "reasoning": raw_text,
            "warning": warning,
        }

    def classify_with_openai(self, image_path: str) -> Dict:
        """Ask the OpenAI vision model to pick one label for the image.

        Returns:
            The parsed JSON reply (expected keys: ``label``, ``confidence``,
            ``reasoning``); an ``"unknown"`` placeholder if the reply cannot
            be parsed; or ``{"error": ...}`` when no API key is configured or
            the request fails.

        Raises:
            OSError: if *image_path* cannot be read.
        """
        if self.openai_client is None:
            return {
                "error": "OPENAI_API_KEY is missing. Add it as an environment variable or HF Space secret.",
            }

        prompt = (
            "Classify the image into one label from the following list: "
            f"{', '.join(self.labels)}. "
            "Return valid JSON with exactly these keys: label, confidence, reasoning. "
            "confidence must be a numeric value between 0 and 1."
        )

        mime_type = self._guess_image_mime(image_path)
        base64_image = self._encode_image(image_path)
        try:
            response = self.openai_client.responses.create(
                model=self.openai_model,
                input=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "input_text", "text": prompt},
                            {
                                "type": "input_image",
                                "image_url": f"data:{mime_type};base64,{base64_image}",
                            },
                        ],
                    }
                ],
            )
        except Exception as exc:
            # A failing remote call must not discard the local models' results
            # in classify_all(); report it like the missing-key case instead.
            _logger.exception("OpenAI classification request failed")
            return {"error": f"OpenAI request failed: {exc}"}

        return self._parse_openai_json(response.output_text or "")

    def classify_all(self, image_path: str) -> Dict:
        """Classify *image_path* with all three models.

        Returns:
            A dict keyed by a display name per model. The two pipeline entries
            hold ``"model"`` and ``"top_3"``; the OpenAI entry holds
            ``"model"`` and ``"prediction"``.
        """
        custom_results = self.custom_classifier(image_path)
        clip_results = self.clip_detector(image_path, candidate_labels=self.labels)
        openai_results = self.classify_with_openai(image_path)

        return {
            "Custom Transfer Learning Model": {
                "model": self.custom_model_name,
                "top_3": self._to_top_k_dict(custom_results, k=3),
            },
            "Open-Source Zero-Shot (CLIP)": {
                "model": self.clip_model_name,
                "top_3": self._to_top_k_dict(clip_results, k=3),
            },
            "Closed-Source Vision Model (OpenAI)": {
                "model": self.openai_model,
                "prediction": openai_results,
            },
        }
|
|
|
|
def discover_example_images(example_dir: str = "example_images") -> List[List[str]]:
    """List the image files directly inside *example_dir*, sorted by path.

    A relative *example_dir* is resolved against ``PROJECT_DIR``. Only regular
    files whose extension is ``.jpg``, ``.jpeg``, ``.png`` or ``.webp``
    (case-insensitive) are included; sub-directories are not searched.

    Returns:
        One single-element list ``[path_string]`` per image, or ``[]`` when
        *example_dir* is missing or is not a directory.
    """
    image_suffixes = {".jpg", ".jpeg", ".png", ".webp"}

    path = Path(example_dir)
    if not path.is_absolute():
        path = PROJECT_DIR / path

    # is_dir() rather than exists(): iterdir() raises on a regular file.
    if not path.is_dir():
        return []

    images = sorted(
        entry
        for entry in path.iterdir()
        # is_file() keeps directories that merely look like images out.
        if entry.suffix.lower() in image_suffixes and entry.is_file()
    )
    return [[str(entry)] for entry in images]
|
|
|
|
def classify_for_table(comparison: "ModelComparison", image_path: str) -> Tuple[str, str, str]:
    """Classify *image_path* and format each model's result as one table cell.

    Args:
        comparison: Object whose ``classify_all(image_path)`` returns the
            three-model result dict.
        image_path: Path of the image to classify.

    Returns:
        ``(custom, clip, openai)`` strings. The first two are
        ``"label: score"`` pairs joined by ``"; "``. The third is
        ``"label: confidence"``, or ``"error: <message>"`` when the OpenAI
        prediction carries an ``"error"`` key.
    """
    result = comparison.classify_all(image_path)

    custom_top = result["Custom Transfer Learning Model"]["top_3"]
    clip_top = result["Open-Source Zero-Shot (CLIP)"]["top_3"]
    openai_pred = result["Closed-Source Vision Model (OpenAI)"]["prediction"]

    custom_str = "; ".join(f"{label}: {score}" for label, score in custom_top.items())
    clip_str = "; ".join(f"{label}: {score}" for label, score in clip_top.items())

    if not isinstance(openai_pred, dict):
        openai_str = str(openai_pred)
    elif "error" in openai_pred:
        # Surface the failure reason instead of a misleading "unknown: n/a".
        openai_str = f"error: {openai_pred['error']}"
    else:
        openai_label = str(openai_pred.get("label", "unknown"))
        openai_conf = str(openai_pred.get("confidence", "n/a"))
        openai_str = f"{openai_label}: {openai_conf}"

    return custom_str, clip_str, openai_str
|
|