# Captured from the Hugging Face file view (not part of the program):
# uploaded by Gilfoyle-alised · commit 5bbd9c7 "fix error" · 7.05 kB
from __future__ import annotations
import gc
import os
from pathlib import Path
from typing import Dict, List
import gradio as gr
import torch
from PIL import Image
from custom_caption_model import LoadedCustomModel, load_custom_model
# Cap PyTorch's CPU intra-op thread pool at between 1 and 4 threads, based on
# os.cpu_count() (which may return None, hence the "or 1" fallback).
torch.set_num_threads(max(1, min(4, os.cpu_count() or 1)))
# Directory containing this file; model artifacts are resolved relative to it.
ROOT = Path(__file__).resolve().parent
# Artifact directories for the two custom models and the BLIP LoRA adapter.
CUSTOM_5K_DIR = ROOT / "models" / "custom_5k"
CUSTOM_100K_DIR = ROOT / "models" / "custom_100k"
BLIP_LORA_DIR = ROOT / "models" / "blip_lora"
# Run on CUDA when available, otherwise on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Labels shown in the model dropdown. caption_one dispatches on these exact
# strings, so the two places must be kept in sync.
MODEL_CHOICES = [
    "Custom EfficientNet + Transformer — 5k",
    "Custom EfficientNet + Transformer — 100k",
    "BLIP + LoRA — COCO 2014",
]
# Lazily populated caches, all reset by unload_models().
# Custom models are stored under the keys "custom_5k" and "custom_100k".
_custom_cache: Dict[str, LoadedCustomModel] = {}
# BLIP processor/model pair, filled in by _load_blip_lora() on first use.
_blip_processor = None
_blip_model = None
def _load_custom_5k() -> LoadedCustomModel:
    """Return the "5k" custom EfficientNet + Transformer captioner.

    The model is built with ``load_custom_model`` on first use and memoized in
    ``_custom_cache`` under the key ``"custom_5k"``; later calls return the
    cached instance.
    """
    cache_key = "custom_5k"
    try:
        return _custom_cache[cache_key]
    except KeyError:
        pass  # cache miss: fall through and build the model
    loaded = load_custom_model(
        checkpoint_path=CUSTOM_5K_DIR / "best_phase-5k.pt",
        vocab_path=CUSTOM_5K_DIR / "vocab-5k.json",
        device=DEVICE,
    )
    _custom_cache[cache_key] = loaded
    return loaded
def _load_custom_100k() -> LoadedCustomModel:
    """Return the "100k" custom EfficientNet + Transformer captioner.

    The model is built with ``load_custom_model`` on first use and memoized in
    ``_custom_cache`` under the key ``"custom_100k"``; later calls return the
    cached instance.
    """
    cache_key = "custom_100k"
    try:
        return _custom_cache[cache_key]
    except KeyError:
        pass  # cache miss: fall through and build the model
    loaded = load_custom_model(
        checkpoint_path=CUSTOM_100K_DIR / "best_phase-100k.pt",
        vocab_path=CUSTOM_100K_DIR / "vocab-100k.json",
        device=DEVICE,
    )
    _custom_cache[cache_key] = loaded
    return loaded
def _load_blip_lora():
    """Return the cached ``(processor, model)`` pair for BLIP + LoRA.

    On first use (or after ``unload_models``) this:
      * loads the processor from ``BLIP_LORA_DIR``,
      * loads the base ``Salesforce/blip-image-captioning-base`` weights,
      * wraps them with the LoRA adapter stored in ``BLIP_LORA_DIR``,
      * moves the model to ``DEVICE`` and switches it to eval mode.

    The module globals are assigned only after every step succeeds. A failure
    part-way through (e.g. running out of memory in ``.to(DEVICE)``) therefore
    leaves the cache empty, and the next call retries. It never returns a
    half-initialised model on the wrong device or in training mode.

    Returns:
        Tuple ``(processor, model)``.

    Raises:
        Whatever the ``transformers``/``peft`` loading calls raise.
    """
    global _blip_processor, _blip_model
    if _blip_model is None or _blip_processor is None:
        # Imported lazily so the heavy dependencies load only when BLIP is used.
        from transformers import BlipForConditionalGeneration, BlipProcessor
        from peft import PeftModel

        base_model_id = "Salesforce/blip-image-captioning-base"
        processor = BlipProcessor.from_pretrained(str(BLIP_LORA_DIR))
        base_model = BlipForConditionalGeneration.from_pretrained(base_model_id)
        model = PeftModel.from_pretrained(base_model, str(BLIP_LORA_DIR))
        model = model.to(DEVICE)
        model.eval()
        # Publish both globals together, only once loading fully succeeded.
        _blip_processor, _blip_model = processor, model
    return _blip_processor, _blip_model
def _caption_blip_lora(image: Image.Image) -> str:
    """Caption *image* with the LoRA-adapted BLIP model.

    The image is converted to RGB and encoded by the BLIP processor. Generation
    uses 4-beam search capped at 50 new tokens, without gradient tracking. The
    decoded text has special tokens removed and surrounding whitespace stripped.
    """
    blip_processor, blip_model = _load_blip_lora()
    rgb = image.convert("RGB")
    batch = blip_processor(images=rgb, return_tensors="pt")
    batch = batch.to(DEVICE)
    with torch.inference_mode():
        token_ids = blip_model.generate(
            **batch,
            num_beams=4,
            max_new_tokens=50,
        )
    caption = blip_processor.decode(token_ids[0], skip_special_tokens=True)
    return caption.strip()
def caption_one(image: Image.Image, model_choice: str, decoding: str) -> str:
    """Generate a caption for *image* with the model named by *model_choice*.

    Args:
        image: Uploaded PIL image, or ``None`` when nothing has been uploaded.
        model_choice: One of the labels in ``MODEL_CHOICES``.
        decoding: Decoding strategy label (``"Beam search"`` or ``"Greedy"``
            in the UI). It is forwarded to the custom models' ``caption``
            method and ignored by BLIP.

    Returns:
        The caption text, or a human-readable message in three cases: no image
        was given, the model label is unknown, or captioning raised an
        exception. In the last case the full traceback is also logged.
    """
    if image is None:
        return "Please upload an image first."
    # Each runner is only called when selected, so only that model gets loaded.
    runners = {
        "Custom EfficientNet + Transformer — 5k": lambda: _load_custom_5k().caption(
            image, decoding=decoding
        ),
        "Custom EfficientNet + Transformer — 100k": lambda: _load_custom_100k().caption(
            image, decoding=decoding
        ),
        "BLIP + LoRA — COCO 2014": lambda: _caption_blip_lora(image),
    }
    run = runners.get(model_choice)
    if run is None:
        return "Unknown model selected."
    try:
        return run()
    except Exception as exc:
        # UI boundary: show a short message on the page, but keep the full
        # traceback in the server log instead of discarding it.
        import logging

        logging.getLogger(__name__).exception(
            "Captioning failed (model=%r, decoding=%r)", model_choice, decoding
        )
        return f"Error: {type(exc).__name__}: {exc}"
def compare_all(image: Image.Image) -> str:
    """Caption *image* with every model in ``MODEL_CHOICES`` and format the results.

    Every model runs with beam search. The result is Markdown with one section
    per model: the bold model label followed by its caption as a block quote.
    Sections are separated by horizontal rules.
    """
    if image is None:
        return "Please upload an image first."
    strategy = "Beam search"
    sections = [
        f"**{label}**\n\n> {caption_one(image, label, strategy)}"
        for label in MODEL_CHOICES
    ]
    return "\n\n---\n\n".join(sections)
def unload_models() -> str:
    """Drop every cached model and ask Python/PyTorch to release the memory.

    Clears the custom-model cache, resets the BLIP processor/model globals,
    runs a full garbage collection and, when CUDA is present, empties PyTorch's
    CUDA cache.

    Returns:
        A status message for the UI's "Status" textbox.
    """
    global _blip_processor, _blip_model
    _custom_cache.clear()
    _blip_processor, _blip_model = None, None
    # Collect unreachable objects before returning cached GPU blocks.
    gc.collect()
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()
    status = "All cached models unloaded."
    return status
def update_decoding_visibility(model_choice: str):
    """Show the decoding-strategy dropdown only for the custom models.

    Only the custom EfficientNet + Transformer models use the decoding strategy;
    BLIP always generates with beam search. The control is therefore hidden
    for any other selection.

    A cleared dropdown (``None``) counts as "not a custom model" instead of
    raising ``TypeError`` on the substring test.

    Returns:
        A Gradio update that sets the dropdown's ``visible`` property.
    """
    is_custom = "Custom" in (model_choice or "")
    return gr.update(visible=is_custom)
CSS = """
.header-text { text-align: center; margin-bottom: 8px; }
.output-box textarea { font-size: 1.05em !important; line-height: 1.7 !important; }
.tag-row { display: flex; gap: 8px; flex-wrap: wrap; margin-top: 4px; }
footer { display: none !important; }
"""
with gr.Blocks(title="Image Captioning") as demo:
gr.Markdown(
"""
# 🖼️ Image Captioning
Upload a photo and generate a natural-language description using one of three trained models — or compare all at once.
""",
elem_classes=["header-text"],
)
with gr.Row(equal_height=False):
with gr.Column(scale=5):
image_input = gr.Image(
type="pil",
label="Upload Image",
height=360,
)
with gr.Column(scale=4):
model_dropdown = gr.Dropdown(
choices=MODEL_CHOICES,
value=MODEL_CHOICES[2],
label="Model",
info="BLIP + LoRA produces the best captions",
)
decoding_dropdown = gr.Dropdown(
choices=["Beam search", "Greedy"],
value="Beam search",
label="Decoding strategy",
info="Beam search is slower but produces better results",
visible=False,
)
gr.Markdown(
"""
<div style="font-size:0.82em; color:#6b7280; margin-top:4px;">
<b>Model details</b><br>
• <b>Custom 5k / 100k</b> — EfficientNet-V2-S + Transformer, trained from scratch on COCO subsets<br>
• <b>BLIP + LoRA</b> — Salesforce BLIP base fine-tuned with LoRA adapters on COCO 2014
</div>
"""
)
with gr.Row():
generate_btn = gr.Button(
"Generate Caption", variant="primary", scale=3, size="lg"
)
compare_btn = gr.Button("Compare All", scale=2, size="lg")
with gr.Group():
output = gr.Markdown(
value="",
label="Caption",
elem_classes=["output-box"],
)
with gr.Accordion("Advanced", open=False):
unload_btn = gr.Button("Unload Cached Models", variant="stop", size="sm")
unload_status = gr.Textbox(label="Status", lines=1, interactive=False)
unload_btn.click(fn=unload_models, inputs=None, outputs=unload_status)
gr.Markdown(
f"<div style='text-align:center; font-size:0.8em; color:#9ca3af; margin-top:8px;'>"
f"Runtime device: <code>{DEVICE}</code> · First inference per model is slower (lazy loading)"
f"</div>"
)
model_dropdown.change(
fn=update_decoding_visibility,
inputs=model_dropdown,
outputs=decoding_dropdown,
)
generate_btn.click(
fn=caption_one,
inputs=[image_input, model_dropdown, decoding_dropdown],
outputs=output,
)
compare_btn.click(
fn=compare_all,
inputs=[image_input],
outputs=output,
)
if __name__ == "__main__":
demo.launch(
theme=gr.themes.Soft(primary_hue="blue", secondary_hue="slate"),
css=CSS,
)