File size: 5,729 Bytes
16c7630 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 | import base64
import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Tuple

from dotenv import load_dotenv
from openai import OpenAI
from transformers import pipeline

from labels import load_labels
# Run python-dotenv's loader at import time, before the os.getenv lookups
# performed further down in this module.
load_dotenv()
# Directory containing this file; relative model and example-image paths in
# this module are resolved against it.
PROJECT_DIR = Path(__file__).resolve().parent
class ModelComparison:
    """Run one image through three classifiers and collect their predictions.

    The three back ends are:

    * a "custom" Hugging Face image-classification pipeline (see
      ``_load_custom_classifier`` for how the model is chosen),
    * a zero-shot CLIP pipeline restricted to ``self.labels``,
    * an OpenAI vision model, used only when ``OPENAI_API_KEY`` is set.

    Configuration is read from the environment: ``OPENAI_MODEL``,
    ``OPENAI_API_KEY``, ``HF_MODEL_ID`` and ``CLIP_MODEL_ID``.
    """

    # MIME types for the data URL sent to OpenAI, keyed by lower-cased file
    # suffix. Unknown suffixes keep the historical "image/jpeg" declaration.
    _IMAGE_MIME_TYPES = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".webp": "image/webp",
        ".gif": "image/gif",
    }

    def __init__(self) -> None:
        """Load the label list and build all three classifiers."""
        # NOTE(review): load_labels() is assumed to return an iterable of
        # label strings (it is joined with ", " below) - confirm in labels.py.
        self.labels = load_labels()
        self.openai_model = os.getenv("OPENAI_MODEL", "gpt-4.1-mini")
        self.openai_client = self._create_openai_client()
        self.custom_classifier, self.custom_model_name = self._load_custom_classifier()
        # Resolve the CLIP model id once so the name reported by classify_all
        # is always the model that was actually loaded here.
        self.clip_model_name = os.getenv("CLIP_MODEL_ID", "openai/clip-vit-large-patch14")
        self.clip_detector = pipeline(
            task="zero-shot-image-classification",
            model=self.clip_model_name,
        )

    @staticmethod
    def _create_openai_client() -> "OpenAI | None":
        """Return an OpenAI client, or None when OPENAI_API_KEY is unset/empty."""
        api_key = os.getenv("OPENAI_API_KEY")
        return OpenAI(api_key=api_key) if api_key else None

    @staticmethod
    def _encode_image(image_path: str) -> str:
        """Return the file's raw bytes as a base64-encoded ASCII string."""
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode("utf-8")

    @staticmethod
    def _guess_image_mime_type(image_path: str) -> str:
        """Return the MIME type implied by the file suffix (default JPEG)."""
        suffix = Path(image_path).suffix.lower()
        return ModelComparison._IMAGE_MIME_TYPES.get(suffix, "image/jpeg")

    @staticmethod
    def _parse_openai_output(raw_text: str) -> Dict:
        """Turn the model's text reply into a prediction dict.

        Accepts plain JSON as well as JSON wrapped in a Markdown code fence.
        Anything that does not decode to a JSON object yields a placeholder
        dict with label "unknown" and the raw reply under "reasoning".
        """
        text = (raw_text or "").strip()
        if text.startswith("```"):
            # Drop the opening fence line (it may carry a language tag such
            # as "json") and the closing fence, keeping only the payload.
            newline_at = text.find("\n")
            text = text[newline_at + 1:] if newline_at != -1 else ""
            text = text.rstrip()
            if text.endswith("```"):
                text = text[:-3]
            text = text.strip()

        warning = "OpenAI response was not valid JSON."
        try:
            parsed = json.loads(text)
        except json.JSONDecodeError:
            parsed = None
        else:
            if isinstance(parsed, dict):
                return parsed
            # Valid JSON, but callers expect an object with named keys.
            warning = "OpenAI response was not a JSON object."

        return {
            "label": "unknown",
            "confidence": 0.0,
            "reasoning": raw_text,
            "warning": warning,
        }

    def _load_custom_classifier(self):
        """Load the first usable image-classification model.

        Candidates are tried in order: the ``HF_MODEL_ID`` environment
        variable, the local ``models/custom-vit-model`` directory (only if it
        exists), then ``google/vit-base-patch16-224``.

        Returns:
            A ``(pipeline, model_name)`` tuple for the model that loaded.

        Raises:
            Whatever ``pipeline`` raises for the final fallback model when no
            candidate could be loaded.
        """
        fallback = "google/vit-base-patch16-224"
        local_model_dir = PROJECT_DIR / "models" / "custom-vit-model"
        candidates = [
            os.getenv("HF_MODEL_ID"),
            str(local_model_dir),
            fallback,
        ]
        for candidate in candidates:
            if not candidate:
                continue
            if candidate == str(local_model_dir) and not local_model_dir.exists():
                continue
            try:
                clf = pipeline(task="image-classification", model=candidate)
            except Exception as exc:
                # Best-effort: move on to the next candidate, but leave a
                # trace so a misconfigured model id is not silently replaced.
                logging.getLogger(__name__).warning(
                    "Could not load image classifier %r: %s", candidate, exc
                )
                continue
            return clf, candidate
        # Every candidate failed: retry the fallback outside the try block so
        # the underlying error reaches the caller instead of being swallowed.
        return pipeline(task="image-classification", model=fallback), fallback

    @staticmethod
    def _to_top_k_dict(results: List[Dict], k: int = 3) -> Dict[str, float]:
        """Map the first ``k`` results to ``{label: score}``, scores rounded to 4 places.

        The input order is kept as-is; no sorting is done here.
        """
        return {
            result["label"]: round(float(result["score"]), 4)
            for result in results[:k]
        }

    def classify_with_openai(self, image_path: str) -> Dict:
        """Ask the OpenAI model to pick one label for the image.

        Args:
            image_path: Path of the image file to send.

        Returns:
            The decoded JSON object from the model (requested keys: label,
            confidence, reasoning); a placeholder dict with a "warning" key
            when the reply cannot be decoded to an object; or a dict with an
            "error" key when no API key is configured.

        Raises:
            OSError if the image cannot be read; errors raised by the OpenAI
            client are not caught here.
        """
        if self.openai_client is None:
            return {
                "error": "OPENAI_API_KEY is missing. Add it as an environment variable or HF Space secret.",
            }
        prompt = (
            "Classify the image into one label from the following list: "
            f"{', '.join(self.labels)}. "
            "Return valid JSON with exactly these keys: label, confidence, reasoning. "
            "confidence must be a numeric value between 0 and 1."
        )
        base64_image = self._encode_image(image_path)
        # Declare the real image type; PNG/WebP bytes labelled as JPEG are
        # a mismatch between the data URL header and its payload.
        mime_type = self._guess_image_mime_type(image_path)
        response = self.openai_client.responses.create(
            model=self.openai_model,
            input=[
                {
                    "role": "user",
                    "content": [
                        {"type": "input_text", "text": prompt},
                        {
                            "type": "input_image",
                            "image_url": f"data:{mime_type};base64,{base64_image}",
                        },
                    ],
                }
            ],
        )
        return self._parse_openai_output(response.output_text)

    def classify_all(self, image_path: str) -> Dict:
        """Classify the image with all three back ends.

        Returns:
            A dict keyed by a display name per back end. The two pipeline
            entries hold "model" and "top_3"; the OpenAI entry holds "model"
            and "prediction".
        """
        custom_results = self.custom_classifier(image_path)
        clip_results = self.clip_detector(image_path, candidate_labels=self.labels)
        openai_results = self.classify_with_openai(image_path)
        return {
            "Custom Transfer Learning Model": {
                "model": self.custom_model_name,
                "top_3": self._to_top_k_dict(custom_results, k=3),
            },
            "Open-Source Zero-Shot (CLIP)": {
                "model": self.clip_model_name,
                "top_3": self._to_top_k_dict(clip_results, k=3),
            },
            "Closed-Source Vision Model (OpenAI)": {
                "model": self.openai_model,
                "prediction": openai_results,
            },
        }
def discover_example_images(example_dir: str = "example_images") -> List[List[str]]:
    """List the image files in a directory, sorted by path.

    Args:
        example_dir: Directory to scan. A relative path is resolved against
            ``PROJECT_DIR``. The scan is not recursive.

    Returns:
        One single-element list ``[path]`` per regular file whose suffix is
        .jpg, .jpeg, .png or .webp (case-insensitive). Returns an empty list
        when the directory is missing or the path is not a directory.
    """
    directory = Path(example_dir)
    if not directory.is_absolute():
        directory = PROJECT_DIR / directory
    # is_dir() rather than exists(): iterating a plain file would raise
    # NotADirectoryError instead of yielding "no examples".
    if not directory.is_dir():
        return []
    image_suffixes = {".jpg", ".jpeg", ".png", ".webp"}
    images = sorted(
        entry
        for entry in directory.iterdir()
        # Skip sub-directories that merely carry an image-like suffix.
        if entry.is_file() and entry.suffix.lower() in image_suffixes
    )
    return [[str(entry)] for entry in images]
def classify_for_table(comparison: "ModelComparison", image_path: str) -> Tuple[str, str, str]:
    """Classify one image and flatten each model's output to a display string.

    Args:
        comparison: Object whose ``classify_all(image_path)`` returns the
            nested result dict with "top_3" / "prediction" entries.
        image_path: Path of the image to classify.

    Returns:
        ``(custom, clip, openai)`` strings. The first two are
        "label: score" pairs joined by "; ". The third is "label: confidence",
        or "error: <message>" when the OpenAI prediction reports an error,
        or ``str(prediction)`` when the prediction is not a dict.
    """
    result = comparison.classify_all(image_path)
    custom_top = result["Custom Transfer Learning Model"]["top_3"]
    clip_top = result["Open-Source Zero-Shot (CLIP)"]["top_3"]
    openai_pred = result["Closed-Source Vision Model (OpenAI)"]["prediction"]

    custom_str = "; ".join(f"{label}: {score}" for label, score in custom_top.items())
    clip_str = "; ".join(f"{label}: {score}" for label, score in clip_top.items())

    if not isinstance(openai_pred, dict):
        openai_str = str(openai_pred)
    elif "error" in openai_pred:
        # Surface the failure (e.g. a missing API key) instead of rendering
        # it as the misleading placeholder "unknown: n/a".
        openai_str = f"error: {openai_pred['error']}"
    else:
        openai_label = str(openai_pred.get("label", "unknown"))
        openai_conf = str(openai_pred.get("confidence", "n/a"))
        openai_str = f"{openai_label}: {openai_conf}"
    return custom_str, clip_str, openai_str
|