ImageCaptioning / app.py
MRaudhatul's picture
Update app.py
35d8749 verified
Raw
History Blame Contribute Delete
10.6 kB
import time
import torch
import pandas as pd
import gradio as gr
from transformers import (
InstructBlipProcessor,
InstructBlipForConditionalGeneration
)
# =====================================================
# MODEL
# =====================================================
# Identifier handed to from_pretrained() for both the processor and the model.
MODEL_ID = "MRaudhatul/instructblip-coco-captioning"
# Processor and model are created once, at import time, and reused by
# generate_response() for every request.
print("Loading processor...")
processor = InstructBlipProcessor.from_pretrained(MODEL_ID)
print("Loading model...")
# Weights are requested in float32; no .to(device) call is made here, so the
# model stays wherever from_pretrained() places it.
model = InstructBlipForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32
)
# Put the module in evaluation mode before serving requests.
model.eval()
print("Model loaded successfully")
# =====================================================
# DATASET STATS
# =====================================================
# Hard-coded counts displayed on the "Model Evaluation" tab; nothing in this
# file computes them.
TRAIN_IMAGES, VALID_IMAGES = 26613, 2958
TRAIN_CAPTIONS, VALID_CAPTIONS = 133142, 14794
# =====================================================
# METRICS
# =====================================================
# Hard-coded scores displayed on the "Model Evaluation" tab
# (presumably from an offline evaluation run - confirm with the author).
BLEU1, BLEU2, BLEU3, BLEU4 = 0.7798, 0.6066, 0.4547, 0.3290
ROUGE_L = 0.5909
METEOR = 0.5790
CIDER = 0.9931

# (label, score) pairs in the order they appear in the table.
_METRIC_ROWS = [
    ("BLEU-1", BLEU1),
    ("BLEU-2", BLEU2),
    ("BLEU-3", BLEU3),
    ("BLEU-4", BLEU4),
    ("ROUGE-L", ROUGE_L),
    ("METEOR", METEOR),
    ("CIDEr", CIDER),
]
metrics_df = pd.DataFrame({
    "Metric": [label for label, _ in _METRIC_ROWS],
    "Score": [score for _, score in _METRIC_ROWS],
})
# =====================================================
# TASK PROMPTS
# =====================================================
# Maps each preset task name offered in the dropdown to the instruction text
# sent to the model. "Custom Prompt" is intentionally absent: its text comes
# from the user's textbox instead.
PROMPTS = {
    "Generate Caption": "Describe this image.",
    "Detailed Caption": "Describe this image in detail.",
    "Identify Main Objects": "What are the main objects in this image?",
    "Explain Scene": "Explain what is happening in this image.",
}
# =====================================================
# CSS — fully responsive, mobile-first
# =====================================================
# Custom stylesheet passed to gr.Blocks(css=...). It caps the container at
# 960px, hides the footer, and gives the rows tagged with the img-row /
# stats-row / metrics-row / result-row classes a wrapping flex layout so they
# stack on narrow screens. The string is sent to the browser verbatim.
# NOTE(review): .subtitle is styled here but no element in the visible UI uses it.
css = """
/* ── reset & base ── */
*, *::before, *::after { box-sizing: border-box; }
html, body {
width: 100% !important;
overflow-x: hidden !important;
}
.gradio-container {
min-width: unset !important;
width: 100% !important;
max-width: 960px !important;
margin: 0 auto !important;
padding: 0 12px !important;
}
/* paksa semua elemen tidak overflow */
.block, .form, .panel {
min-width: unset !important;
width: 100% !important;
}
footer { display: none !important; }
/* ── title block ── */
.main-title {
text-align: center;
font-size: clamp(22px, 5vw, 38px);
font-weight: 700;
margin: 16px 0 6px;
line-height: 1.2;
word-break: break-word;
}
.subtitle {
text-align: center;
font-size: clamp(13px, 3vw, 16px);
color: #666;
margin-bottom: 20px;
}
/* ── tab labels ── */
.tab-nav button {
font-size: clamp(12px, 2.5vw, 15px) !important;
padding: 8px 10px !important;
}
/* ── generate tab: stack on mobile, side-by-side on desktop ── */
.img-row {
display: flex;
flex-wrap: wrap;
gap: 12px;
}
.img-row > * {
flex: 1 1 280px;
min-width: 0;
}
/* ── stats row ── */
.stats-row {
display: flex;
flex-wrap: wrap;
gap: 10px;
}
.stats-row > * {
flex: 1 1 130px;
min-width: 0;
}
/* ── metrics row ── */
.metrics-row {
display: flex;
flex-wrap: wrap;
gap: 10px;
margin-bottom: 12px;
}
.metrics-row > * {
flex: 1 1 100px;
min-width: 0;
}
/* ── inputs & buttons ── */
button.primary {
width: 100% !important;
font-size: clamp(14px, 3vw, 16px) !important;
padding: 10px 16px !important;
}
/* ── labels: prevent overflow ── */
label span {
white-space: normal !important;
word-break: break-word !important;
font-size: clamp(11px, 2.5vw, 14px) !important;
}
/* ── dataframe: horizontal scroll on small screens ── */
.gr-dataframe, .svelte-table-wrap, table {
overflow-x: auto !important;
display: block !important;
width: 100% !important;
font-size: clamp(11px, 2.5vw, 14px) !important;
}
/* ── confidence / time row ── */
.result-row {
display: flex;
flex-wrap: wrap;
gap: 10px;
}
.result-row > * {
flex: 1 1 140px;
min-width: 0;
}
/* ── markdown content ── */
.gr-markdown p,
.gr-markdown li {
font-size: clamp(13px, 2.8vw, 15px) !important;
line-height: 1.6 !important;
}
.gr-markdown h2 {
font-size: clamp(16px, 4vw, 22px) !important;
}
.gr-markdown h3 {
font-size: clamp(14px, 3vw, 18px) !important;
}
/* ── image components ── */
.gr-image img {
max-width: 100% !important;
height: auto !important;
}
/* ── small screens ── */
@media (max-width: 480px) {
.gradio-container {
padding: 0 8px !important;
}
.tab-nav button {
font-size: 11px !important;
padding: 6px 6px !important;
}
}
"""
# =====================================================
# SHOW/HIDE CUSTOM PROMPT
# =====================================================
def toggle_prompt(task):
    """Return a Gradio update that shows the custom-prompt box only for "Custom Prompt".

    Args:
        task: The currently selected value of the task dropdown.

    Returns:
        A ``gr.update`` payload whose ``visible`` flag is True when ``task``
        is "Custom Prompt" and False for every other value.
    """
    wants_custom = task == "Custom Prompt"
    return gr.update(visible=wants_custom)
# =====================================================
# INFERENCE
# =====================================================
def _resolve_prompt(task, custom_prompt):
    """Return the instruction text to send to the model for the selected task."""
    if task != "Custom Prompt":
        return PROMPTS[task]
    # The textbox value can arrive as None (e.g. a null payload), which would
    # make a bare .strip() raise; blank input falls back to the default prompt.
    return (custom_prompt or "").strip() or "Describe this image."


def _mean_top_probability(scores):
    """Return the mean, over generation steps, of the top softmax probability (in %)."""
    step_maxima = [torch.softmax(step, dim=-1).max().item() for step in scores]
    if not step_maxima:
        return 0
    return sum(step_maxima) / len(step_maxima) * 100


def generate_response(image, task, custom_prompt):
    """Run the model on one image and return the values for the four output widgets.

    Args:
        image: The uploaded image, or None when nothing was provided.
        task: Selected task name; either a key of ``PROMPTS`` or "Custom Prompt".
        custom_prompt: Free-text instruction, used only for "Custom Prompt".
            None or blank text falls back to "Describe this image.".

    Returns:
        A 4-tuple ``(image, caption, confidence, elapsed)`` where ``confidence``
        is formatted like "87.12%" and ``elapsed`` like "3.40 sec". When
        ``image`` is None the tuple is
        ``(None, "Please upload an image.", "-", "-")``.

    Raises:
        KeyError: If ``task`` is neither "Custom Prompt" nor a key of ``PROMPTS``.
    """
    if image is None:
        return (None, "Please upload an image.", "-", "-")

    prompt = _resolve_prompt(task, custom_prompt)

    # perf_counter() is monotonic, so the reported duration cannot be skewed
    # by system clock adjustments the way time.time() can.
    start = time.perf_counter()
    inputs = processor(images=image, text=prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=60,
            output_scores=True,
            return_dict_in_generate=True,
        )
    generated_text = processor.batch_decode(
        outputs.sequences, skip_special_tokens=True
    )[0].strip()

    # Heuristic confidence: the highest probability at each step, averaged.
    # NOTE(review): this equals the probability of the emitted token only if
    # decoding picks the arg-max at every step - confirm the generation config.
    confidence = _mean_top_probability(outputs.scores)

    elapsed = time.perf_counter() - start
    return (image, generated_text, f"{confidence:.2f}%", f"{elapsed:.2f} sec")
# =====================================================
# UI
# =====================================================
# Builds the three-tab interface (Home, Generate, Model Evaluation) and starts
# the server. Emoji in the title and tab labels are written as real Unicode
# characters; they had been stored as mis-decoded byte sequences.
with gr.Blocks(
    title="AI Image Captioning System",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"),
    css=css,
) as demo:
    gr.HTML('<div class="main-title">🖼️ AI Image Captioning System</div>')

    with gr.Tabs():
        # -----------------------------------------
        # HOME: static project description
        # -----------------------------------------
        with gr.Tab("🏠 Home"):
            gr.Markdown("""
## Project Information
### Project Name
AI Image Captioning System
### Student
Muhammad Raudhatul
### Model
InstructBLIP FLAN-T5 XL
### Dataset
MS COCO 2017
### Description
Image caption generator using InstructBLIP model. Upload any image to get an automatic caption.
### Deployment
Hugging Face Spaces
""")

        # -----------------------------------------
        # GENERATE: upload an image, pick a task, run the model
        # -----------------------------------------
        with gr.Tab("🖼️ Generate"):
            # Input and echo of the original image; the img-row class lets
            # the two wrap onto separate lines on narrow screens.
            with gr.Row(elem_classes="img-row"):
                image_input = gr.Image(
                    sources=["upload", "webcam"],
                    type="pil",
                    label="Input Image",
                )
                image_output = gr.Image(label="Original Image")

            task_dropdown = gr.Dropdown(
                choices=[
                    "Generate Caption",
                    "Detailed Caption",
                    "Identify Main Objects",
                    "Explain Scene",
                    "Custom Prompt",
                ],
                value="Generate Caption",
                label="Task",
            )
            # Hidden until the "Custom Prompt" task is selected.
            custom_prompt = gr.Textbox(
                label="Custom Prompt",
                placeholder="Enter your instruction...",
                visible=False,
            )
            task_dropdown.change(toggle_prompt, task_dropdown, custom_prompt)

            generate_btn = gr.Button("Generate Caption", variant="primary")
            response_output = gr.Textbox(label="Generated Caption", lines=4)

            # Confidence and timing; the result-row class lets them wrap.
            with gr.Row(elem_classes="result-row"):
                confidence_output = gr.Textbox(label="Confidence Score")
                time_output = gr.Textbox(label="Inference Time")

            # The output order must match the tuple returned by generate_response.
            generate_btn.click(
                fn=generate_response,
                inputs=[image_input, task_dropdown, custom_prompt],
                outputs=[image_output, response_output, confidence_output, time_output],
            )

        # -----------------------------------------
        # MODEL EVALUATION: hard-coded dataset counts and metric scores
        # -----------------------------------------
        with gr.Tab("📊 Model Evaluation"):
            gr.Markdown("## Dataset Statistics")
            with gr.Row(elem_classes="stats-row"):
                gr.Number(value=TRAIN_IMAGES, label="Training Images")
                gr.Number(value=VALID_IMAGES, label="Validation Images")
            with gr.Row(elem_classes="stats-row"):
                gr.Number(value=TRAIN_CAPTIONS, label="Training Captions")
                gr.Number(value=VALID_CAPTIONS, label="Validation Captions")

            gr.Markdown("## Model Performance")
            with gr.Row(elem_classes="metrics-row"):
                gr.Number(value=BLEU4, label="BLEU-4")
                gr.Number(value=ROUGE_L, label="ROUGE-L")
                gr.Number(value=CIDER, label="CIDEr")
            gr.Dataframe(value=metrics_df, interactive=False)

# Launched unconditionally at import, exactly as before.
demo.launch()