from __future__ import annotations import gc import os from pathlib import Path from typing import Dict, List import gradio as gr import torch from PIL import Image from custom_caption_model import LoadedCustomModel, load_custom_model torch.set_num_threads(max(1, min(4, os.cpu_count() or 1))) ROOT = Path(__file__).resolve().parent CUSTOM_5K_DIR = ROOT / "models" / "custom_5k" CUSTOM_100K_DIR = ROOT / "models" / "custom_100k" BLIP_LORA_DIR = ROOT / "models" / "blip_lora" DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") MODEL_CHOICES = [ "Custom EfficientNet + Transformer — 5k", "Custom EfficientNet + Transformer — 100k", "BLIP + LoRA — COCO 2014", ] _custom_cache: Dict[str, LoadedCustomModel] = {} _blip_processor = None _blip_model = None def _load_custom_5k() -> LoadedCustomModel: key = "custom_5k" if key not in _custom_cache: _custom_cache[key] = load_custom_model( checkpoint_path=CUSTOM_5K_DIR / "best_phase-5k.pt", vocab_path=CUSTOM_5K_DIR / "vocab-5k.json", device=DEVICE, ) return _custom_cache[key] def _load_custom_100k() -> LoadedCustomModel: key = "custom_100k" if key not in _custom_cache: _custom_cache[key] = load_custom_model( checkpoint_path=CUSTOM_100K_DIR / "best_phase-100k.pt", vocab_path=CUSTOM_100K_DIR / "vocab-100k.json", device=DEVICE, ) return _custom_cache[key] def _load_blip_lora(): global _blip_processor, _blip_model if _blip_model is None or _blip_processor is None: from transformers import BlipForConditionalGeneration, BlipProcessor from peft import PeftModel base_model_id = "Salesforce/blip-image-captioning-base" _blip_processor = BlipProcessor.from_pretrained(str(BLIP_LORA_DIR)) base_model = BlipForConditionalGeneration.from_pretrained(base_model_id) _blip_model = PeftModel.from_pretrained(base_model, str(BLIP_LORA_DIR)) _blip_model = _blip_model.to(DEVICE) _blip_model.eval() return _blip_processor, _blip_model def _caption_blip_lora(image: Image.Image) -> str: processor, model = _load_blip_lora() image = image.convert("RGB") inputs = processor(images=image, return_tensors="pt").to(DEVICE) with torch.inference_mode(): output = model.generate(**inputs, max_new_tokens=50, num_beams=4) return processor.decode(output[0], skip_special_tokens=True).strip() def caption_one(image: Image.Image, model_choice: str, decoding: str) -> str: if image is None: return "Please upload an image first." try: if model_choice == "Custom EfficientNet + Transformer — 5k": model = _load_custom_5k() return model.caption(image, decoding=decoding) if model_choice == "Custom EfficientNet + Transformer — 100k": model = _load_custom_100k() return model.caption(image, decoding=decoding) if model_choice == "BLIP + LoRA — COCO 2014": return _caption_blip_lora(image) return "Unknown model selected." except Exception as exc: return f"Error: {type(exc).__name__}: {exc}" def compare_all(image: Image.Image) -> str: if image is None: return "Please upload an image first." rows: List[str] = [] for choice in MODEL_CHOICES: caption = caption_one(image, choice, "Beam search") rows.append(f"**{choice}**\n\n> {caption}") return "\n\n---\n\n".join(rows) def unload_models() -> str: global _blip_processor, _blip_model _custom_cache.clear() _blip_processor = None _blip_model = None gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() return "All cached models unloaded." def update_decoding_visibility(model_choice: str): return gr.update(visible="Custom" in model_choice) CSS = """ .header-text { text-align: center; margin-bottom: 8px; } .output-box textarea { font-size: 1.05em !important; line-height: 1.7 !important; } .tag-row { display: flex; gap: 8px; flex-wrap: wrap; margin-top: 4px; } footer { display: none !important; } """ with gr.Blocks(title="Image Captioning") as demo: gr.Markdown( """ # 🖼️ Image Captioning Upload a photo and generate a natural-language description using one of three trained models — or compare all at once. """, elem_classes=["header-text"], ) with gr.Row(equal_height=False): with gr.Column(scale=5): image_input = gr.Image( type="pil", label="Upload Image", height=360, ) with gr.Column(scale=4): model_dropdown = gr.Dropdown( choices=MODEL_CHOICES, value=MODEL_CHOICES[2], label="Model", info="BLIP + LoRA produces the best captions", ) decoding_dropdown = gr.Dropdown( choices=["Beam search", "Greedy"], value="Beam search", label="Decoding strategy", info="Beam search is slower but produces better results", visible=False, ) gr.Markdown( """
{DEVICE} · First inference per model is slower (lazy loading)"
f"