| from __future__ import annotations |
|
|
| import gc |
| import os |
| from pathlib import Path |
| from typing import Dict, List |
|
|
| import gradio as gr |
| import torch |
| from PIL import Image |
|
|
| from custom_caption_model import LoadedCustomModel, load_custom_model |
|
|
# Limit the number of threads PyTorch uses to between 1 and 4.
# os.cpu_count() may return None, hence the fallback to 1.
_cpu_total = os.cpu_count() or 1
torch.set_num_threads(max(1, min(4, _cpu_total)))


# Model artefacts live in per-model sub-folders of "models", next to this file.
ROOT = Path(__file__).resolve().parent
_MODELS_ROOT = ROOT / "models"
CUSTOM_5K_DIR = _MODELS_ROOT / "custom_5k"
CUSTOM_100K_DIR = _MODELS_ROOT / "custom_100k"
BLIP_LORA_DIR = _MODELS_ROOT / "blip_lora"


# Use CUDA when torch reports it as available, otherwise the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_device_name)


# Labels shown in the model dropdown; caption_one() compares against these
# exact strings, so they must be kept in sync with it.
MODEL_CHOICES = [
    "Custom EfficientNet + Transformer — 5k",
    "Custom EfficientNet + Transformer — 100k",
    "BLIP + LoRA — COCO 2014",
]


# Lazily populated caches: models are loaded on first use and kept until
# unload_models() resets them.
_custom_cache: Dict[str, LoadedCustomModel] = {}
_blip_processor = None
_blip_model = None
|
|
|
|
def _load_custom_5k() -> LoadedCustomModel:
    """Return the custom model stored under ``CUSTOM_5K_DIR``, loading it on first use.

    The loaded model is kept in ``_custom_cache`` under the key
    ``"custom_5k"`` so later calls reuse it instead of loading again.
    """
    if "custom_5k" in _custom_cache:
        return _custom_cache["custom_5k"]

    _custom_cache["custom_5k"] = load_custom_model(
        checkpoint_path=CUSTOM_5K_DIR / "best_phase-5k.pt",
        vocab_path=CUSTOM_5K_DIR / "vocab-5k.json",
        device=DEVICE,
    )
    return _custom_cache["custom_5k"]
|
|
|
|
def _load_custom_100k() -> LoadedCustomModel:
    """Return the custom model stored under ``CUSTOM_100K_DIR``, loading it on first use.

    The loaded model is kept in ``_custom_cache`` under the key
    ``"custom_100k"`` so later calls reuse it instead of loading again.
    """
    if "custom_100k" in _custom_cache:
        return _custom_cache["custom_100k"]

    _custom_cache["custom_100k"] = load_custom_model(
        checkpoint_path=CUSTOM_100K_DIR / "best_phase-100k.pt",
        vocab_path=CUSTOM_100K_DIR / "vocab-100k.json",
        device=DEVICE,
    )
    return _custom_cache["custom_100k"]
|
|
|
|
def _load_blip_lora():
    """Return the cached ``(processor, model)`` pair for BLIP + LoRA, loading it on first use.

    The processor and the LoRA adapter are read from ``BLIP_LORA_DIR``; the
    base weights come from the ``Salesforce/blip-image-captioning-base``
    checkpoint. The assembled model is moved to ``DEVICE`` and put in eval
    mode.

    The module-level cache is only updated once every step has succeeded.
    If any step raises (for example moving the model to ``DEVICE``), the
    cache is left untouched and the next call retries from scratch instead
    of returning a half-initialised model.

    Returns:
        A ``(processor, model)`` tuple.
    """
    global _blip_processor, _blip_model
    if _blip_model is not None and _blip_processor is not None:
        return _blip_processor, _blip_model

    # Imported here rather than at module top, so these packages are only
    # imported when the BLIP model is first loaded.
    from transformers import BlipForConditionalGeneration, BlipProcessor
    from peft import PeftModel

    base_model_id = "Salesforce/blip-image-captioning-base"
    adapter_dir = str(BLIP_LORA_DIR)

    # Build everything in locals first; publish to the globals only at the
    # end so a failure part-way through cannot poison the cache.
    processor = BlipProcessor.from_pretrained(adapter_dir)
    base_model = BlipForConditionalGeneration.from_pretrained(base_model_id)
    model = PeftModel.from_pretrained(base_model, adapter_dir)
    model = model.to(DEVICE)
    model.eval()

    _blip_processor, _blip_model = processor, model
    return processor, model
|
|
|
|
def _caption_blip_lora(image: Image.Image) -> str:
    """Caption ``image`` with the BLIP + LoRA model.

    The image is converted to RGB, run through the processor, and decoded
    from a 4-beam generation limited to 50 new tokens.

    Returns:
        The decoded caption with special tokens removed and surrounding
        whitespace stripped.
    """
    blip_processor, blip_model = _load_blip_lora()
    rgb_image = image.convert("RGB")
    batch = processor_inputs = blip_processor(images=rgb_image, return_tensors="pt")
    batch = processor_inputs.to(DEVICE)
    with torch.inference_mode():
        token_ids = blip_model.generate(**batch, max_new_tokens=50, num_beams=4)
    caption = blip_processor.decode(token_ids[0], skip_special_tokens=True)
    return caption.strip()
|
|
|
|
def caption_one(image: Image.Image, model_choice: str, decoding: str) -> str:
    """Caption ``image`` with the model named by ``model_choice``.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.
        model_choice: One of the labels listed in ``MODEL_CHOICES``.
        decoding: Forwarded as ``decoding=`` to the custom models' ``caption``
            method; not used for the BLIP + LoRA model.

    Returns:
        The caption, or a human-readable message when no image was given,
        the model label is not recognised, or captioning raised.
    """
    if image is None:
        return "Please upload an image first."

    # This function feeds the UI directly, so any failure is reported as
    # text rather than propagated.
    try:
        if model_choice == "BLIP + LoRA — COCO 2014":
            return _caption_blip_lora(image)

        if model_choice == "Custom EfficientNet + Transformer — 5k":
            custom_model = _load_custom_5k()
        elif model_choice == "Custom EfficientNet + Transformer — 100k":
            custom_model = _load_custom_100k()
        else:
            return "Unknown model selected."

        return custom_model.caption(image, decoding=decoding)
    except Exception as exc:
        return f"Error: {type(exc).__name__}: {exc}"
|
|
|
|
def compare_all(image: Image.Image) -> str:
    """Caption ``image`` with every entry of ``MODEL_CHOICES`` and format the results.

    Each model is run through ``caption_one`` with ``"Beam search"`` decoding.

    Returns:
        A Markdown string with one bold model heading and quoted caption per
        model, separated by horizontal rules, or a prompt to upload an image
        when ``image`` is ``None``.
    """
    if image is None:
        return "Please upload an image first."

    sections = [
        f"**{choice}**\n\n> {caption_one(image, choice, 'Beam search')}"
        for choice in MODEL_CHOICES
    ]
    return "\n\n---\n\n".join(sections)
|
|
|
|
def unload_models() -> str:
    """Drop every cached model and ask Python and CUDA to release memory.

    Resets the BLIP processor/model globals, empties ``_custom_cache``, runs
    a garbage-collection pass and, when CUDA is available, empties torch's
    CUDA cache.

    Returns:
        A fixed status message for display in the UI.
    """
    global _blip_processor, _blip_model
    _custom_cache.clear()
    _blip_processor = _blip_model = None

    # Collect first so the dropped models are gone before the CUDA cache
    # is emptied.
    gc.collect()
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()
    return "All cached models unloaded."
|
|
|
|
def update_decoding_visibility(model_choice: str):
    """Show the decoding dropdown only for model labels containing ``"Custom"``.

    Returns:
        A ``gr.update`` whose ``visible`` flag is set accordingly.
    """
    show_decoding = "Custom" in model_choice
    return gr.update(visible=show_decoding)
|
|
|
|
# Custom stylesheet passed to `demo.launch(css=...)` at the bottom of the file.
# NOTE(review): `.output-box textarea` targets a textarea, but the component
# tagged `output-box` in this file is a gr.Markdown — confirm the selector
# still matches anything.
# NOTE(review): `.tag-row` is not referenced by any component visible in this
# file — confirm whether it is still needed.
CSS = """
.header-text { text-align: center; margin-bottom: 8px; }
.output-box textarea { font-size: 1.05em !important; line-height: 1.7 !important; }
.tag-row { display: flex; gap: 8px; flex-wrap: wrap; margin-top: 4px; }
footer { display: none !important; }
"""
|
|
# Declarative Gradio UI. Components are created inside the `demo` context and
# the event bindings at the bottom connect them to the captioning callbacks
# defined earlier in this file.
with gr.Blocks(title="Image Captioning") as demo:

    # Page header.
    gr.Markdown(
        """
        # 🖼️ Image Captioning
        Upload a photo and generate a natural-language description using one of three trained models — or compare all at once.
        """,
        elem_classes=["header-text"],
    )

    with gr.Row(equal_height=False):
        # Left column: the image to caption.
        with gr.Column(scale=5):
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                height=360,
            )

        # Right column: model selection and actions.
        with gr.Column(scale=4):
            # Defaults to the third entry of MODEL_CHOICES (BLIP + LoRA).
            model_dropdown = gr.Dropdown(
                choices=MODEL_CHOICES,
                value=MODEL_CHOICES[2],
                label="Model",
                info="BLIP + LoRA produces the best captions",
            )
            # Hidden initially because the default model is BLIP, whose
            # captioning path does not take a decoding argument;
            # update_decoding_visibility() reveals it for the custom models.
            decoding_dropdown = gr.Dropdown(
                choices=["Beam search", "Greedy"],
                value="Beam search",
                label="Decoding strategy",
                info="Beam search is slower but produces better results",
                visible=False,
            )

            # Static description of the available models.
            gr.Markdown(
                """
                <div style="font-size:0.82em; color:#6b7280; margin-top:4px;">
                <b>Model details</b><br>
                • <b>Custom 5k / 100k</b> — EfficientNet-V2-S + Transformer, trained from scratch on COCO subsets<br>
                • <b>BLIP + LoRA</b> — Salesforce BLIP base fine-tuned with LoRA adapters on COCO 2014
                </div>
                """
            )

            # Action buttons.
            with gr.Row():
                generate_btn = gr.Button(
                    "Generate Caption", variant="primary", scale=3, size="lg"
                )
                compare_btn = gr.Button("Compare All", scale=2, size="lg")

    # Result area; both buttons write their Markdown output here.
    with gr.Group():
        output = gr.Markdown(
            value="",
            label="Caption",
            elem_classes=["output-box"],
        )

    # Maintenance controls: drop all cached models and show the status text
    # returned by unload_models().
    with gr.Accordion("Advanced", open=False):
        unload_btn = gr.Button("Unload Cached Models", variant="stop", size="sm")
        unload_status = gr.Textbox(label="Status", lines=1, interactive=False)
        unload_btn.click(fn=unload_models, inputs=None, outputs=unload_status)

    # Footer showing which torch device was selected at import time.
    gr.Markdown(
        f"<div style='text-align:center; font-size:0.8em; color:#9ca3af; margin-top:8px;'>"
        f"Runtime device: <code>{DEVICE}</code> · First inference per model is slower (lazy loading)"
        f"</div>"
    )

    # Event wiring.
    # Toggle the decoding dropdown whenever the selected model changes.
    model_dropdown.change(
        fn=update_decoding_visibility,
        inputs=model_dropdown,
        outputs=decoding_dropdown,
    )
    # Caption with the currently selected model and decoding strategy.
    generate_btn.click(
        fn=caption_one,
        inputs=[image_input, model_dropdown, decoding_dropdown],
        outputs=output,
    )
    # Caption with every model; compare_all() always uses "Beam search".
    compare_btn.click(
        fn=compare_all,
        inputs=[image_input],
        outputs=output,
    )
|
|
if __name__ == "__main__":
    # Start the Gradio server only when this file is run as a script.
    # NOTE(review): theme and css are passed to launch() rather than to
    # gr.Blocks(); this relies on a Gradio release whose launch() accepts
    # them — confirm against the installed version.
    app_theme = gr.themes.Soft(primary_hue="blue", secondary_hue="slate")
    demo.launch(theme=app_theme, css=CSS)
|
|