Spaces:
Sleeping
Sleeping
| """ | |
| Computer Vision Classification & Model Comparison App | |
| Compares: | |
| 1) Custom transfer-learning classifier (trained on custom image data) | |
| 2) Open-source CLIP zero-shot classifier | |
| 3) OpenAI vision model classifier | |
| """ | |
| import base64 | |
| import json | |
| import os | |
| import sys | |
| from io import BytesIO | |
| from typing import Dict, List, Tuple | |
| import gradio as gr | |
| import torch | |
| from huggingface_hub import snapshot_download | |
| from openai import OpenAI | |
| from PIL import Image | |
| from torchvision import models, transforms | |
| from transformers import CLIPModel, CLIPProcessor | |
| from transformers.utils import logging as hf_logging | |
# Only let the transformers library log at ERROR level, keeping model-loading
# chatter out of the Space logs.
hf_logging.set_verbosity(hf_logging.ERROR)
def _silence_known_asyncio_shutdown_warning(unraisable) -> None:
    """Unraisable-exception hook that drops one known, harmless shutdown error.

    A ``ValueError`` whose message contains ``"Invalid file descriptor: -1"`` is
    swallowed; everything else is forwarded to Python's default
    ``sys.__unraisablehook__``.
    """
    error = getattr(unraisable, "exc_value", None)
    is_known_noise = isinstance(error, ValueError) and "Invalid file descriptor: -1" in str(error)
    if not is_known_noise:
        sys.__unraisablehook__(unraisable)


# Install the filter process-wide at import time.
sys.unraisablehook = _silence_known_asyncio_shutdown_warning
# Directory containing this file; bundled resources are resolved relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Fallback labels, used when CLASS_NAMES_PATH is missing or does not hold a
# non-empty JSON list. The list order presumably matches the custom model's
# output indices (note "dodge" is last, not alphabetical) - confirm against the
# training script before reordering.
DEFAULT_CLASS_NAMES = [
    "audi", "bmw", "chevrolet", "ford", "honda", "hyundai",
    "mercedes-benz", "nissan", "tesla", "toyota", "volkswagen", "dodge",
]

# Custom-model artifacts; the directory is overridable via CUSTOM_MODEL_DIR.
# NOTE(review): the default directory is named "custom_resnet18" while the
# default architecture elsewhere is "squeezenet1_1" - confirm which is intended.
MODEL_DIR = os.environ.get("CUSTOM_MODEL_DIR", os.path.join(BASE_DIR, "models", "custom_resnet18"))
MODEL_WEIGHTS_PATH = os.path.join(MODEL_DIR, "model.pth")
MODEL_WEIGHTS_B64_PATH = os.path.join(MODEL_DIR, "model.pth.b64")
CLASS_NAMES_PATH = os.path.join(MODEL_DIR, "class_names.json")
MODEL_CONFIG_PATH = os.path.join(MODEL_DIR, "model_config.json")

# Remote settings, both overridable through environment variables.
OPENAI_MODEL_NAME = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")
HF_MODEL_REPO = os.environ.get("HF_MODEL_REPO", "")
def load_class_names() -> List[str]:
    """Return the class labels used by the custom model.

    Reads ``CLASS_NAMES_PATH`` when it exists and contains a non-empty JSON list.
    Otherwise returns a copy of ``DEFAULT_CLASS_NAMES``, so callers never alias the
    module-level default.

    An unreadable or malformed file falls back to the defaults instead of raising
    at import time. This matches the best-effort handling in ``load_model_arch``
    and keeps the CLIP/OpenAI comparisons usable.
    """
    if os.path.exists(CLASS_NAMES_PATH):
        try:
            with open(CLASS_NAMES_PATH, "r", encoding="utf-8") as f:
                names = json.load(f)
        except (OSError, ValueError):
            # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses.
            names = None
        if isinstance(names, list) and names:
            return names
    return list(DEFAULT_CLASS_NAMES)
# Label list shared by all three classifiers; its length sizes the custom head.
CLASS_NAMES = load_class_names()

# Preprocessing for the custom model: resize, centre-crop to 224x224, convert to
# a tensor and normalise with the widely used ImageNet mean/std (presumably the
# same values used at training time - confirm against the training script).
CUSTOM_TRANSFORM = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Populated by the loading code below. When the custom model is unavailable,
# custom_model stays None and custom_model_error explains why.
custom_model: torch.nn.Module | None = None
custom_model_error: str | None = None
def load_model_arch() -> str:
    """Pick the custom model architecture name.

    Prefers the ``"arch"`` entry of ``MODEL_CONFIG_PATH`` when it names a
    supported architecture ("resnet18" or "squeezenet1_1"). Any problem reading
    or interpreting that file is ignored (best effort). The fallback is the
    ``CUSTOM_MODEL_ARCH`` environment variable, defaulting to "squeezenet1_1".
    """
    supported = ("resnet18", "squeezenet1_1")
    if os.path.exists(MODEL_CONFIG_PATH):
        try:
            with open(MODEL_CONFIG_PATH, "r", encoding="utf-8") as handle:
                configured = json.load(handle).get("arch")
        except Exception:
            # Unreadable file, bad JSON, or a non-object top level: use the fallback.
            configured = None
        if configured in supported:
            return configured
    return os.getenv("CUSTOM_MODEL_ARCH", "squeezenet1_1")
def build_custom_model(arch: str, num_classes: int) -> torch.nn.Module:
    """Create an untrained torchvision backbone with a ``num_classes``-way head.

    Args:
        arch: Either "resnet18" or "squeezenet1_1".
        num_classes: Number of output classes for the replaced classifier layer.

    Returns:
        The model with randomly initialised weights (``weights=None``).

    Raises:
        ValueError: If ``arch`` is not one of the supported names.
    """
    if arch == "resnet18":
        net = models.resnet18(weights=None)
        # Swap the final fully connected layer for one sized to our label set.
        net.fc = torch.nn.Linear(net.fc.in_features, num_classes)
    elif arch == "squeezenet1_1":
        net = models.squeezenet1_1(weights=None)
        # SqueezeNet classifies with a 1x1 conv over 512 channels; resize its output.
        net.classifier[1] = torch.nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        net.num_classes = num_classes
    else:
        raise ValueError(f"Unsupported custom model arch: {arch}")
    return net
def _maybe_download_model_from_hf() -> None:
    """Download the ``HF_MODEL_REPO`` snapshot into ``MODEL_DIR`` if needed.

    Does nothing when ``MODEL_WEIGHTS_PATH`` already exists or no repo is
    configured. Errors from the download are not caught here.
    """
    weights_present = os.path.exists(MODEL_WEIGHTS_PATH)
    if weights_present or not HF_MODEL_REPO:
        return
    os.makedirs(MODEL_DIR, exist_ok=True)
    # NOTE(review): local_dir_use_symlinks is reportedly deprecated in newer
    # huggingface_hub releases - verify against the installed version.
    snapshot_download(
        repo_id=HF_MODEL_REPO,
        local_dir=MODEL_DIR,
        local_dir_use_symlinks=False,
    )
def _load_custom_model(arch: str, weights_source) -> torch.nn.Module:
    """Build ``arch`` sized for ``CLASS_NAMES`` and load weights from ``weights_source``.

    ``weights_source`` is anything ``torch.load`` accepts here: a filesystem path
    or a binary file-like object. The model is returned in eval mode, and only
    after its state dict loaded successfully. A failed load therefore never
    leaves a randomly initialised network behind.
    """
    model = build_custom_model(arch, len(CLASS_NAMES))
    # NOTE(review): torch.load unpickles its input. On torch versions where
    # weights_only defaults to False, a malicious checkpoint (e.g. one fetched
    # from HF_MODEL_REPO) could execute code. Consider weights_only=True once the
    # minimum supported torch version is confirmed.
    state_dict = torch.load(weights_source, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model


# Best-effort download. Remember the failure so it is still reported if no local
# weights are found afterwards.
_download_error = None
try:
    _maybe_download_model_from_hf()
except Exception as exc:  # pragma: no cover
    _download_error = f"Failed to download custom model from Hugging Face repo '{HF_MODEL_REPO}': {exc}"
    custom_model_error = _download_error

# custom_model is assigned only after a successful load. Previously a failed
# load_state_dict left the freshly built (untrained) model in place, and
# predict_custom would silently use it.
if os.path.exists(MODEL_WEIGHTS_PATH):
    try:
        model_arch = load_model_arch()
        custom_model = _load_custom_model(model_arch, MODEL_WEIGHTS_PATH)
        custom_model_error = None
    except Exception as exc:  # pragma: no cover
        custom_model_error = str(exc)
elif os.path.exists(MODEL_WEIGHTS_B64_PATH):
    try:
        model_arch = load_model_arch()
        with open(MODEL_WEIGHTS_B64_PATH, "r", encoding="utf-8") as f:
            decoded = base64.b64decode(f.read().strip())
        custom_model = _load_custom_model(model_arch, BytesIO(decoded))
        custom_model_error = None
    except Exception as exc:  # pragma: no cover
        custom_model_error = f"Failed to load model from text weights {MODEL_WEIGHTS_B64_PATH}: {exc}"
else:
    custom_model_error = (
        f"Custom model not found at {MODEL_WEIGHTS_PATH} or {MODEL_WEIGHTS_B64_PATH}. "
        "Train first with train_transfer_learning.py and export model.pth.b64, or set HF_MODEL_REPO to auto-download."
    )
    if _download_error:
        # Keep the root cause visible instead of replacing it with "not found".
        custom_model_error = f"{_download_error} {custom_model_error}"
# Open-source zero-shot baseline. A loading failure is recorded in clip_error
# rather than raised, so the other classifiers stay available.
clip_model = clip_processor = clip_error = None
try:
    clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
except Exception as load_exc:  # pragma: no cover
    clip_error = str(load_exc)
def _topk(probs: torch.Tensor, labels: List[str], k: int = 3) -> List[Tuple[str, float]]:
    """Return the ``k`` highest-scoring ``(label, probability)`` pairs, best first.

    ``k`` is capped at ``len(labels)``. ``probs`` is assumed to be a 1-D tensor
    aligned with ``labels``.
    """
    best = torch.topk(probs, k=min(k, len(labels)))
    ranked: List[Tuple[str, float]] = []
    for score, position in zip(best.values, best.indices):
        ranked.append((labels[position], float(score)))
    return ranked
def _sorted_probabilities(labels: List[str], probs: torch.Tensor) -> Dict[str, float]:
    """Map each label to its probability, inserted in descending-probability order.

    Ties keep their original label order (stable sort). ``probs`` is indexed by
    label position, so it must have at least ``len(labels)`` entries.
    """
    ranked = sorted(
        ((label, float(probs[idx])) for idx, label in enumerate(labels)),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return dict(ranked)
def predict_custom(
    image: Image.Image,
    clip_probabilities: Dict[str, float] | None = None,
    confidence_threshold: float = 0.20,
    custom_weight: float = 0.2,
) -> Dict:
    """Classify ``image`` with the custom transfer-learning model.

    If the custom model's top-1 probability is below ``confidence_threshold`` and
    ``clip_probabilities`` is given, the prediction is blended with the CLIP
    distribution, renormalised over ``CLASS_NAMES``:
    ``custom_weight * custom + (1 - custom_weight) * clip``. The blend (and its
    note) is applied only when the CLIP prior actually has probability mass, so
    the result always remains a distribution.

    Args:
        image: Input image; converted to RGB before preprocessing.
        clip_probabilities: Optional ``{label: probability}`` mapping (the
            ``"probabilities"`` entry of ``predict_clip``). Missing labels count as 0.
        confidence_threshold: Top-1 probability below which blending is considered.
        custom_weight: Weight of the custom model in the blend, in [0, 1].

    Returns:
        ``{"error": ...}`` if the custom model is unavailable. Otherwise a dict with
        ``top1``, ``top3`` (``(label, prob)`` pairs), ``probabilities`` (label ->
        prob, descending) and, when blending happened, ``note``.
    """
    if custom_model is None:
        return {"error": custom_model_error}

    x = CUSTOM_TRANSFORM(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        logits = custom_model(x)
    probs = torch.softmax(logits[0], dim=0)

    note = ""
    if clip_probabilities and float(torch.max(probs)) < confidence_threshold:
        clip_prior = torch.tensor(
            [float(clip_probabilities.get(label, 0.0)) for label in CLASS_NAMES],
            dtype=probs.dtype,
            device=probs.device,
        )
        prior_sum = float(torch.sum(clip_prior))
        # With an all-zero prior, blending would shrink probs below a valid
        # distribution and the note would be false, so skip it entirely.
        if prior_sum > 0:
            clip_prior = clip_prior / prior_sum
            probs = custom_weight * probs + (1.0 - custom_weight) * clip_prior
            note = "Low-confidence custom model output blended with CLIP prior."

    top3 = _topk(probs, CLASS_NAMES, k=3)
    result = {
        "top1": top3[0][0],
        "top3": top3,
        "probabilities": _sorted_probabilities(CLASS_NAMES, probs),
    }
    if note:
        result["note"] = note
    return result
def predict_clip(image: Image.Image) -> Dict:
    """Zero-shot classify ``image`` with CLIP against one prompt per class.

    Returns ``{"error": clip_error}`` when CLIP failed to load. Otherwise returns
    ``top1``, ``top3`` (``(label, prob)`` pairs) and ``probabilities`` (label ->
    prob, descending), with a softmax taken over the class prompts.
    """
    if clip_model is None or clip_processor is None:
        return {"error": clip_error}

    batch = clip_processor(
        text=[f"a photo of {name}" for name in CLASS_NAMES],
        images=image.convert("RGB"),
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        outputs = clip_model(**batch)
    # Only one image is in the batch; row 0 scores it against every prompt.
    probs = outputs.logits_per_image[0].softmax(dim=0)

    ranked = _topk(probs, CLASS_NAMES, k=3)
    return {
        "top1": ranked[0][0],
        "top3": ranked,
        "probabilities": _sorted_probabilities(CLASS_NAMES, probs),
    }
def _extract_json_object(text: str) -> dict:
    """Parse the span from the first ``{`` to the last ``}`` in ``text`` as JSON.

    Returns ``{}`` when no such span exists or it does not decode to an object.
    Raises ``json.JSONDecodeError`` if the span is present but invalid JSON.
    """
    start = text.find("{")
    end = text.rfind("}")
    if start < 0 or end <= start:
        return {}
    payload = json.loads(text[start : end + 1])
    return payload if isinstance(payload, dict) else {}


def _coerce_confidence(value) -> float:
    """Convert a model-reported confidence to a float in [0, 1] (0.0 if unusable)."""
    import math

    try:
        confidence = float(value)
    except (TypeError, ValueError):
        return 0.0
    if math.isnan(confidence):
        return 0.0
    return min(max(confidence, 0.0), 1.0)


def predict_openai(image: Image.Image) -> Dict:
    """Classify ``image`` with an OpenAI vision model via the Responses API.

    The model is asked for strict JSON ``{label, confidence, reason}``. A label
    outside ``CLASS_NAMES`` is reported as ``"out-of-set"``. A missing or
    non-numeric confidence becomes 0.0, and any value is clamped to [0, 1] as the
    prompt requires.

    Returns:
        ``{"error": ...}`` when ``OPENAI_API_KEY`` is unset or the request or JSON
        parsing fails. Otherwise ``{"top1": label, "top3": [(label, confidence)],
        "reason": reason}``.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        return {
            "error": "OPENAI_API_KEY is not set. Add it in Hugging Face Space Settings > Secrets."
        }

    client = OpenAI(api_key=api_key)
    # Send the image inline as a base64 PNG data URL.
    buffered = BytesIO()
    image.convert("RGB").save(buffered, format="PNG")
    b64_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
    prompt = (
        "Classify this image into exactly one label from this set: "
        + ", ".join(CLASS_NAMES)
        + ". Return strict JSON with keys: label, confidence, reason. "
        "confidence must be a float between 0 and 1."
    )
    try:
        # NOTE(review): some models (OPENAI_MODEL is configurable) may reject
        # temperature - confirm if the default model is changed.
        response = client.responses.create(
            model=OPENAI_MODEL_NAME,
            temperature=0,
            input=[
                {
                    "role": "user",
                    "content": [
                        {"type": "input_text", "text": prompt},
                        {"type": "input_image", "image_url": f"data:image/png;base64,{b64_image}"},
                    ],
                }
            ],
        )
        payload = _extract_json_object((response.output_text or "").strip())
        label = str(payload.get("label", "unknown")).strip().lower()
        confidence = _coerce_confidence(payload.get("confidence", 0.0))
        reason = str(payload.get("reason", ""))
        if label not in CLASS_NAMES:
            label = "out-of-set"
        return {
            "top1": label,
            "top3": [(label, confidence)],
            "reason": reason,
        }
    except Exception as exc:  # pragma: no cover
        return {"error": str(exc)}
def compare_models(image: Image.Image):
    """Run all three classifiers on ``image`` and shape the results for the UI.

    CLIP runs first because its probabilities are passed to ``predict_custom``
    (only when CLIP succeeded).

    Returns:
        A ``(comparison, summary)`` pair. ``comparison`` holds the class list and
        one entry per model (either its ``error`` or its predictions). ``summary``
        is a ``" | "``-joined line of the top-1 label of each model that succeeded.
    """
    if image is None:
        return {"error": "Please upload an image.", "classes": CLASS_NAMES}, ""

    clip_pred = predict_clip(image)
    clip_probs = None if "error" in clip_pred else clip_pred.get("probabilities")
    custom_pred = predict_custom(image, clip_probs)
    openai_pred = predict_openai(image)

    def ranked(pred: Dict) -> List[Dict]:
        # (label, score) tuples -> JSON-friendly dicts.
        return [{"label": label, "score": score} for label, score in pred["top3"]]

    if "error" in custom_pred:
        custom_view = {"error": custom_pred["error"]}
    else:
        custom_view = {
            "top1": custom_pred["top1"],
            "top3": ranked(custom_pred),
            "probabilities": custom_pred["probabilities"],
            "note": custom_pred.get("note", ""),
        }

    if "error" in clip_pred:
        clip_view = {"error": clip_pred["error"]}
    else:
        clip_view = {
            "top1": clip_pred["top1"],
            "top3": ranked(clip_pred),
            "probabilities": clip_pred["probabilities"],
        }

    if "error" in openai_pred:
        openai_view = {"error": openai_pred["error"]}
    else:
        openai_view = {
            "label": openai_pred["top1"],
            "confidence": openai_pred["top3"][0][1] if openai_pred.get("top3") else 0.0,
            "reason": openai_pred.get("reason", ""),
        }

    comparison = {
        "classes": CLASS_NAMES,
        "Custom Transfer Model": custom_view,
        "CLIP Zero-Shot Classification": clip_view,
        "OpenAI Vision Classification": openai_view,
    }
    headline_sources = (("Custom", custom_pred), ("CLIP", clip_pred), ("OpenAI", openai_pred))
    summary = " | ".join(
        f"{source}: {pred['top1']}" for source, pred in headline_sources if "error" not in pred
    )
    return comparison, summary
def clear_outputs():
    """Reset the UI: no image, an empty comparison JSON and an empty summary."""
    empty_image = None
    empty_comparison = {}
    empty_summary = ""
    return empty_image, empty_comparison, empty_summary
| def build_examples() -> List[List[str]]: | |
| examples_dir = os.path.join(BASE_DIR, "example_images") | |
| files = [] | |
| if not os.path.isdir(examples_dir): | |
| return files | |
| for name in sorted(os.listdir(examples_dir)): | |
| lower = name.lower() | |
| if lower.endswith((".png", ".jpg", ".jpeg", ".webp")): | |
| files.append([os.path.join(examples_dir, name)]) | |
| return files[:6] | |
# Gradio UI: a left column with the image input and Clear/Submit buttons, and a
# right column with the per-model JSON comparison and a one-line summary.
with gr.Blocks(title="Car Classification Comparison") as demo:
    gr.Markdown(
        """
        # Car Classification Comparison
        This app compares three approaches on the same uploaded image:
        - Custom transfer-learning model (trained on your custom classes)
        - Open-source CLIP zero-shot model
        - Closed-source OpenAI vision model
        This Space is configured for car classification only.
        """
    )
    # Two columns with equal scale (6 each).
    with gr.Row():
        with gr.Column(scale=6):
            # type="pil" passes a PIL image to compare_models (which also handles None).
            image_input = gr.Image(type="pil", label="Upload an image")
            with gr.Row():
                clear_button = gr.Button("Clear")
                run_button = gr.Button("Submit", variant="primary")
        with gr.Column(scale=6):
            comparison_json = gr.JSON(label="Comparison output")
            summary_text = gr.Textbox(label="Top predictions", interactive=False)
    # compare_models returns (comparison dict, summary str) for these two outputs.
    run_button.click(
        fn=compare_models,
        inputs=[image_input],
        outputs=[comparison_json, summary_text],
    )
    # clear_outputs returns (None, {}, "") to reset the image, JSON and summary.
    clear_button.click(
        fn=clear_outputs,
        inputs=[],
        outputs=[image_input, comparison_json, summary_text],
    )
    # Example rows from build_examples(); selecting one fills image_input.
    gr.Examples(
        examples=build_examples(),
        inputs=[image_input],
        label="Example images",
    )

if __name__ == "__main__":
    # ssr_mode=False explicitly opts out of Gradio's server-side rendering mode.
    demo.launch(ssr_mode=False)