Spaces:
Sleeping
Sleeping
| import time | |
| import torch | |
| import pandas as pd | |
| import gradio as gr | |
| from transformers import ( | |
| InstructBlipProcessor, | |
| InstructBlipForConditionalGeneration | |
| ) | |
# =====================================================
# MODEL
# =====================================================
MODEL_ID = "MRaudhatul/instructblip-coco-captioning"

# Processor and weights are loaded once, at module import time, and shared
# by every request handled by generate_response().
print("Loading processor...")
processor = InstructBlipProcessor.from_pretrained(MODEL_ID)

print("Loading model...")
# nn.Module.eval() returns the module itself, so chaining it keeps `model`
# bound to the same object while switching it to inference mode.
model = InstructBlipForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
).eval()
print("Model loaded successfully")
# =====================================================
# DATASET STATS
# =====================================================
# Hard-coded image/caption counts for the training and validation splits,
# displayed on the "Model Evaluation" tab.
TRAIN_IMAGES = 26_613
VALID_IMAGES = 2_958
TRAIN_CAPTIONS = 133_142
VALID_CAPTIONS = 14_794
# =====================================================
# METRICS
# =====================================================
# Hard-coded caption-quality scores reported for the fine-tuned model.
BLEU1 = 0.7798
BLEU2 = 0.6066
BLEU3 = 0.4547
BLEU4 = 0.3290
ROUGE_L = 0.5909
METEOR = 0.5790
CIDER = 0.9931

# Keeping each label next to its value means the two DataFrame columns
# below cannot drift out of sync when a metric is added or removed.
_METRIC_ROWS = [
    ("BLEU-1", BLEU1),
    ("BLEU-2", BLEU2),
    ("BLEU-3", BLEU3),
    ("BLEU-4", BLEU4),
    ("ROUGE-L", ROUGE_L),
    ("METEOR", METEOR),
    ("CIDEr", CIDER),
]
metrics_df = pd.DataFrame({
    "Metric": [label for label, _ in _METRIC_ROWS],
    "Score": [score for _, score in _METRIC_ROWS],
})
# =====================================================
# TASK PROMPTS
# =====================================================
# Preset task name (as listed in the "Task" dropdown) -> instruction text
# that is sent to the model together with the image.
PROMPTS = dict([
    ("Generate Caption", "Describe this image."),
    ("Detailed Caption", "Describe this image in detail."),
    ("Identify Main Objects", "What are the main objects in this image?"),
    ("Explain Scene", "Explain what is happening in this image."),
])
# =====================================================
# CSS — fully responsive, mobile-first
# =====================================================
# Passed to gr.Blocks(css=...). Rows are flex containers that wrap on narrow
# screens, and font sizes scale with clamp().
# The Indonesian comment on the `.block, .form, .panel` rule means
# "force all elements not to overflow".
# NOTE(review): `.gr-dataframe`, `.gr-markdown` and `.gr-image` look like
# Gradio 3.x class names — confirm they still match elements in the
# installed Gradio version, otherwise those rules are inert.
css = """
/* ── reset & base ── */
*, *::before, *::after { box-sizing: border-box; }
html, body {
    width: 100% !important;
    overflow-x: hidden !important;
}
.gradio-container {
    min-width: unset !important;
    width: 100% !important;
    max-width: 960px !important;
    margin: 0 auto !important;
    padding: 0 12px !important;
}
/* paksa semua elemen tidak overflow */
.block, .form, .panel {
    min-width: unset !important;
    width: 100% !important;
}
footer { display: none !important; }
/* ── title block ── */
.main-title {
    text-align: center;
    font-size: clamp(22px, 5vw, 38px);
    font-weight: 700;
    margin: 16px 0 6px;
    line-height: 1.2;
    word-break: break-word;
}
.subtitle {
    text-align: center;
    font-size: clamp(13px, 3vw, 16px);
    color: #666;
    margin-bottom: 20px;
}
/* ── tab labels ── */
.tab-nav button {
    font-size: clamp(12px, 2.5vw, 15px) !important;
    padding: 8px 10px !important;
}
/* ── generate tab: stack on mobile, side-by-side on desktop ── */
.img-row {
    display: flex;
    flex-wrap: wrap;
    gap: 12px;
}
.img-row > * {
    flex: 1 1 280px;
    min-width: 0;
}
/* ── stats row ── */
.stats-row {
    display: flex;
    flex-wrap: wrap;
    gap: 10px;
}
.stats-row > * {
    flex: 1 1 130px;
    min-width: 0;
}
/* ── metrics row ── */
.metrics-row {
    display: flex;
    flex-wrap: wrap;
    gap: 10px;
    margin-bottom: 12px;
}
.metrics-row > * {
    flex: 1 1 100px;
    min-width: 0;
}
/* ── inputs & buttons ── */
button.primary {
    width: 100% !important;
    font-size: clamp(14px, 3vw, 16px) !important;
    padding: 10px 16px !important;
}
/* ── labels: prevent overflow ── */
label span {
    white-space: normal !important;
    word-break: break-word !important;
    font-size: clamp(11px, 2.5vw, 14px) !important;
}
/* ── dataframe: horizontal scroll on small screens ── */
.gr-dataframe, .svelte-table-wrap, table {
    overflow-x: auto !important;
    display: block !important;
    width: 100% !important;
    font-size: clamp(11px, 2.5vw, 14px) !important;
}
/* ── confidence / time row ── */
.result-row {
    display: flex;
    flex-wrap: wrap;
    gap: 10px;
}
.result-row > * {
    flex: 1 1 140px;
    min-width: 0;
}
/* ── markdown content ── */
.gr-markdown p,
.gr-markdown li {
    font-size: clamp(13px, 2.8vw, 15px) !important;
    line-height: 1.6 !important;
}
.gr-markdown h2 {
    font-size: clamp(16px, 4vw, 22px) !important;
}
.gr-markdown h3 {
    font-size: clamp(14px, 3vw, 18px) !important;
}
/* ── image components ── */
.gr-image img {
    max-width: 100% !important;
    height: auto !important;
}
/* ── small screens ── */
@media (max-width: 480px) {
    .gradio-container {
        padding: 0 8px !important;
    }
    .tab-nav button {
        font-size: 11px !important;
        padding: 6px 6px !important;
    }
}
"""
# =====================================================
# SHOW/HIDE CUSTOM PROMPT
# =====================================================
def toggle_prompt(task):
    """Build a Gradio visibility update for the custom-prompt textbox.

    The textbox becomes visible only when ``task`` is "Custom Prompt";
    every other selection hides it.
    """
    show_textbox = task == "Custom Prompt"
    return gr.update(visible=show_textbox)
| # ===================================================== | |
| # INFERENCE | |
| # ===================================================== | |
| def generate_response(image, task, custom_prompt): | |
| if image is None: | |
| return (None, "Please upload an image.", "-", "-") | |
| if task == "Custom Prompt": | |
| prompt = custom_prompt.strip() | |
| if len(prompt) == 0: | |
| prompt = "Describe this image." | |
| else: | |
| prompt = PROMPTS[task] | |
| start = time.time() | |
| inputs = processor(images=image, text=prompt, return_tensors="pt") | |
| with torch.no_grad(): | |
| outputs = model.generate( | |
| **inputs, | |
| max_new_tokens=60, | |
| output_scores=True, | |
| return_dict_in_generate=True | |
| ) | |
| generated_text = processor.batch_decode( | |
| outputs.sequences, skip_special_tokens=True | |
| )[0] | |
| # ββ confidence score ββ | |
| token_confidences = [] | |
| for score in outputs.scores: | |
| probs = torch.softmax(score, dim=-1) | |
| max_prob = probs.max().item() | |
| token_confidences.append(max_prob) | |
| if len(token_confidences) > 0: | |
| confidence = (sum(token_confidences) / len(token_confidences)) * 100 | |
| else: | |
| confidence = 0 | |
| confidence = f"{confidence:.2f}%" | |
| inference_time = time.time() - start | |
| return (image, generated_text, confidence, f"{inference_time:.2f} sec") | |
# =====================================================
# UI
# =====================================================
# NOTE(review): the emoji in the title and tab labels were mojibake in this
# copy. 🖼️ was recoverable; 🏠 (Home) and 📊 (Model Evaluation) are best
# guesses — confirm against the original.
with gr.Blocks(
    title="AI Image Captioning System",
    theme=gr.themes.Soft(primary_hue="blue", secondary_hue="sky"),
    css=css,
) as demo:
    gr.HTML('<div class="main-title">🖼️ AI Image Captioning System</div>')
    with gr.Tabs():
        # ─────────────────────────────────────────
        # HOME
        # ─────────────────────────────────────────
        with gr.Tab("🏠 Home"):
            gr.Markdown("""
## Project Information
### Project Name
AI Image Captioning System
### Student
Muhammad Raudhatul
### Model
InstructBLIP FLAN-T5 XL
### Dataset
MS COCO 2017
### Description
Image caption generator using InstructBLIP model. Upload any image to get an automatic caption.
### Deployment
Hugging Face Spaces
""")

        # ─────────────────────────────────────────
        # GENERATE
        # ─────────────────────────────────────────
        with gr.Tab("🖼️ Generate"):
            # Input and echoed image side by side; the img-row CSS wraps
            # them onto separate lines on narrow screens.
            with gr.Row(elem_classes="img-row"):
                image_input = gr.Image(
                    sources=["upload", "webcam"],
                    type="pil",
                    label="Input Image",
                )
                image_output = gr.Image(label="Original Image")

            # Preset tasks come from PROMPTS, so the dropdown and the
            # prompt table stay in sync; "Custom Prompt" is appended last.
            task_dropdown = gr.Dropdown(
                choices=[*PROMPTS, "Custom Prompt"],
                value="Generate Caption",
                label="Task",
            )
            custom_prompt = gr.Textbox(
                label="Custom Prompt",
                placeholder="Enter your instruction...",
                visible=False,
            )
            # Show the free-text box only while "Custom Prompt" is selected.
            task_dropdown.change(toggle_prompt, task_dropdown, custom_prompt)

            generate_btn = gr.Button("Generate Caption", variant="primary")
            response_output = gr.Textbox(label="Generated Caption", lines=4)

            # Confidence and time side by side; wraps on mobile.
            with gr.Row(elem_classes="result-row"):
                confidence_output = gr.Textbox(label="Confidence Score")
                time_output = gr.Textbox(label="Inference Time")

            generate_btn.click(
                fn=generate_response,
                inputs=[image_input, task_dropdown, custom_prompt],
                outputs=[image_output, response_output, confidence_output, time_output],
            )

        # ─────────────────────────────────────────
        # MODEL EVALUATION
        # ─────────────────────────────────────────
        with gr.Tab("📊 Model Evaluation"):
            gr.Markdown("## Dataset Statistics")
            with gr.Row(elem_classes="stats-row"):
                gr.Number(value=TRAIN_IMAGES, label="Training Images")
                gr.Number(value=VALID_IMAGES, label="Validation Images")
            with gr.Row(elem_classes="stats-row"):
                gr.Number(value=TRAIN_CAPTIONS, label="Training Captions")
                gr.Number(value=VALID_CAPTIONS, label="Validation Captions")

            gr.Markdown("## Model Performance")
            with gr.Row(elem_classes="metrics-row"):
                gr.Number(value=BLEU4, label="BLEU-4")
                gr.Number(value=ROUGE_L, label="ROUGE-L")
                gr.Number(value=CIDER, label="CIDEr")
            gr.Dataframe(value=metrics_df, interactive=False)

demo.launch()