from __future__ import annotations import gc import os from pathlib import Path from typing import Dict, List import gradio as gr import torch from PIL import Image from custom_caption_model import LoadedCustomModel, load_custom_model torch.set_num_threads(max(1, min(4, os.cpu_count() or 1))) ROOT = Path(__file__).resolve().parent CUSTOM_5K_DIR = ROOT / "models" / "custom_5k" CUSTOM_100K_DIR = ROOT / "models" / "custom_100k" BLIP_LORA_DIR = ROOT / "models" / "blip_lora" DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") MODEL_CHOICES = [ "Custom EfficientNet + Transformer — 5k", "Custom EfficientNet + Transformer — 100k", "BLIP + LoRA — COCO 2014", ] _custom_cache: Dict[str, LoadedCustomModel] = {} _blip_processor = None _blip_model = None def _load_custom_5k() -> LoadedCustomModel: key = "custom_5k" if key not in _custom_cache: _custom_cache[key] = load_custom_model( checkpoint_path=CUSTOM_5K_DIR / "best_phase-5k.pt", vocab_path=CUSTOM_5K_DIR / "vocab-5k.json", device=DEVICE, ) return _custom_cache[key] def _load_custom_100k() -> LoadedCustomModel: key = "custom_100k" if key not in _custom_cache: _custom_cache[key] = load_custom_model( checkpoint_path=CUSTOM_100K_DIR / "best_phase-100k.pt", vocab_path=CUSTOM_100K_DIR / "vocab-100k.json", device=DEVICE, ) return _custom_cache[key] def _load_blip_lora(): global _blip_processor, _blip_model if _blip_model is None or _blip_processor is None: from transformers import BlipForConditionalGeneration, BlipProcessor from peft import PeftModel base_model_id = "Salesforce/blip-image-captioning-base" _blip_processor = BlipProcessor.from_pretrained(str(BLIP_LORA_DIR)) base_model = BlipForConditionalGeneration.from_pretrained(base_model_id) _blip_model = PeftModel.from_pretrained(base_model, str(BLIP_LORA_DIR)) _blip_model = _blip_model.to(DEVICE) _blip_model.eval() return _blip_processor, _blip_model def _caption_blip_lora(image: Image.Image) -> str: processor, model = _load_blip_lora() image = image.convert("RGB") inputs = processor(images=image, return_tensors="pt").to(DEVICE) with torch.inference_mode(): output = model.generate(**inputs, max_new_tokens=50, num_beams=4) return processor.decode(output[0], skip_special_tokens=True).strip() def caption_one(image: Image.Image, model_choice: str, decoding: str) -> str: if image is None: return "Please upload an image first." try: if model_choice == "Custom EfficientNet + Transformer — 5k": model = _load_custom_5k() return model.caption(image, decoding=decoding) if model_choice == "Custom EfficientNet + Transformer — 100k": model = _load_custom_100k() return model.caption(image, decoding=decoding) if model_choice == "BLIP + LoRA — COCO 2014": return _caption_blip_lora(image) return "Unknown model selected." except Exception as exc: return f"Error: {type(exc).__name__}: {exc}" def compare_all(image: Image.Image) -> str: if image is None: return "Please upload an image first." rows: List[str] = [] for choice in MODEL_CHOICES: caption = caption_one(image, choice, "Beam search") rows.append(f"**{choice}**\n\n> {caption}") return "\n\n---\n\n".join(rows) def unload_models() -> str: global _blip_processor, _blip_model _custom_cache.clear() _blip_processor = None _blip_model = None gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() return "All cached models unloaded." def update_decoding_visibility(model_choice: str): return gr.update(visible="Custom" in model_choice) CSS = """ .header-text { text-align: center; margin-bottom: 8px; } .output-box textarea { font-size: 1.05em !important; line-height: 1.7 !important; } .tag-row { display: flex; gap: 8px; flex-wrap: wrap; margin-top: 4px; } footer { display: none !important; } """ with gr.Blocks(title="Image Captioning") as demo: gr.Markdown( """ # 🖼️ Image Captioning Upload a photo and generate a natural-language description using one of three trained models — or compare all at once. """, elem_classes=["header-text"], ) with gr.Row(equal_height=False): with gr.Column(scale=5): image_input = gr.Image( type="pil", label="Upload Image", height=360, ) with gr.Column(scale=4): model_dropdown = gr.Dropdown( choices=MODEL_CHOICES, value=MODEL_CHOICES[2], label="Model", info="BLIP + LoRA produces the best captions", ) decoding_dropdown = gr.Dropdown( choices=["Beam search", "Greedy"], value="Beam search", label="Decoding strategy", info="Beam search is slower but produces better results", visible=False, ) gr.Markdown( """
Model details
Custom 5k / 100k — EfficientNet-V2-S + Transformer, trained from scratch on COCO subsets
BLIP + LoRA — Salesforce BLIP base fine-tuned with LoRA adapters on COCO 2014
""" ) with gr.Row(): generate_btn = gr.Button( "Generate Caption", variant="primary", scale=3, size="lg" ) compare_btn = gr.Button("Compare All", scale=2, size="lg") with gr.Group(): output = gr.Markdown( value="", label="Caption", elem_classes=["output-box"], ) with gr.Accordion("Advanced", open=False): unload_btn = gr.Button("Unload Cached Models", variant="stop", size="sm") unload_status = gr.Textbox(label="Status", lines=1, interactive=False) unload_btn.click(fn=unload_models, inputs=None, outputs=unload_status) gr.Markdown( f"
" f"Runtime device: {DEVICE} · First inference per model is slower (lazy loading)" f"
" ) model_dropdown.change( fn=update_decoding_visibility, inputs=model_dropdown, outputs=decoding_dropdown, ) generate_btn.click( fn=caption_one, inputs=[image_input, model_dropdown, decoding_dropdown], outputs=output, ) compare_btn.click( fn=compare_all, inputs=[image_input], outputs=output, ) if __name__ == "__main__": demo.launch( theme=gr.themes.Soft(primary_hue="blue", secondary_hue="slate"), css=CSS, )