Image_Caption / app.py
Zaman01's picture
Update app.py
f9de6d5 verified
Raw
History Blame Contribute Delete
5.59 kB
# import gradio as gr
# from transformers import BlipProcessor, BlipForConditionalGeneration
# from PIL import Image
# import torch
# import requests
# # Load model & processor
# processor = BlipProcessor.from_pretrained(
# "Salesforce/blip-image-captioning-base"
# )
# model = BlipForConditionalGeneration.from_pretrained(
# "Salesforce/blip-image-captioning-base"
# )
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)
# def caption_image(image, prompt="", openai_api_key=""):
# if not prompt or not prompt.strip():
# return "Please enter a prompt/question for the image."
# image = image.convert("RGB")
# # Use OpenAI API if key provided (unchanged)
# if openai_api_key:
# try:
# import base64
# from io import BytesIO
# buffered = BytesIO()
# image.save(buffered, format="PNG")
# img_b64 = base64.b64encode(buffered.getvalue()).decode()
# headers = {
# "Authorization": f"Bearer {openai_api_key}",
# "Content-Type": "application/json"
# }
# data = {
# "model": "gpt-4-vision-preview",
# "messages": [
# {
# "role": "user",
# "content": [
# {"type": "text", "text": prompt.strip()},
# {"type": "image_url", "image_url": f"data:image/png;base64,{img_b64}"}
# ]
# }
# ],
# "max_tokens": 100
# }
# resp = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=data)
# if resp.status_code == 200:
# result = resp.json()
# return result["choices"][0]["message"]["content"].strip()
# else:
# return f"OpenAI API error: {resp.status_code} {resp.text}"
# except Exception as e:
# return f"OpenAI API error: {e}"
# # BLIP: always use prompt as instruction, no retry, fast settings
# p = prompt.strip()
# prompt_text = f"Question: {p} Answer:"
# inputs = processor(images=image, text=prompt_text, return_tensors="pt").to(device)
# # Speed up: reduce beams and max_length
# gen_kwargs = {"max_length": 25, "num_beams": 1, "early_stopping": True}
# output = model.generate(**inputs, **gen_kwargs)
# caption = processor.decode(output[0], skip_special_tokens=True)
# # Extract answer after 'Answer:' if present
# idx = caption.lower().find("answer:")
# if idx != -1:
# ans = caption[idx + len("answer:"):].strip()
# if ans:
# return ans
# # Otherwise, return the raw caption
# return caption.strip()
# # Gradio UI: horizontal layout with image, prompt, button left; output right
# with gr.Blocks() as demo:
# gr.Markdown("## 🖼️ Image Captioning (Prompt-driven)\nUpload an image, enter a prompt, and click Submit. Output depends on both image and prompt.")
# with gr.Row():
# with gr.Column(scale=2):
# img = gr.Image(type="pil", label="Upload Image")
# prompt = gr.Textbox(label="Prompt (ask a question)", placeholder="What is the color of the t-shirt?")
# openai_api_key = gr.Textbox(label="OpenAI API Key (optional)", type="password", placeholder="sk-...", lines=1)
# btn = gr.Button("Submit")
# with gr.Column(scale=1):
# out = gr.Textbox(label="Answer", lines=6)
# btn.click(fn=caption_image, inputs=[img, prompt, openai_api_key], outputs=out)
# demo.launch()
import gradio as gr
import torch
from transformers import BlipProcessor, BlipForQuestionAnswering
from PIL import Image
# ---------------------------
# Load BLIP VQA model
# ---------------------------
# Checkpoint used for visual question answering (both processor and model).
MODEL_NAME = "Salesforce/blip-vqa-base"

# Prefer the GPU when one is visible to torch; otherwise run on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The processor handles image preprocessing and question tokenization.
processor = BlipProcessor.from_pretrained(MODEL_NAME)
# nn.Module.to() and .eval() both return the module itself, so the calls chain:
# move the weights to `device`, then switch to inference mode.
model = BlipForQuestionAnswering.from_pretrained(MODEL_NAME).to(device).eval()
# ---------------------------
# Inference function
# ---------------------------
def answer_image_question(image, question):
    """Answer a free-form question about an image with the BLIP VQA model.

    Args:
        image: PIL image from the Gradio ``Image`` component (``type="pil"``),
            or ``None`` when nothing has been uploaded.
        question: Question text; ``None`` or whitespace-only text is treated
            as a missing question.

    Returns:
        The decoded model answer with surrounding whitespace removed, or a
        short user-facing message when the image or question is missing.
    """
    if image is None:
        return "Please upload an image."
    # Treat None like an empty string so .strip() can never raise.
    question = (question or "").strip()
    if not question:
        return "Please enter a question."

    # Normalize to 3-channel RGB (uploads may carry alpha/palette modes).
    image = image.convert("RGB")
    inputs = processor(images=image, text=question, return_tensors="pt").to(device)

    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_length=10,  # keep decoding short for speed
            num_beams=1,  # greedy decoding, the fastest setting
        )
    return processor.decode(output[0], skip_special_tokens=True).strip()
# ---------------------------
# Gradio UI
# ---------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🖼️ Image Question Answering (Fast & Accurate)")
    # Markdown renders a lone "\n" as a space, so keep the usage hint to one
    # self-contained sentence with a concrete example question.
    gr.Markdown(
        "Upload an image and ask any question about it, for example: "
        "*What is the color of the shirt?*"
    )
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            img = gr.Image(type="pil", label="Upload Image")
            question = gr.Textbox(
                label="Question",
                placeholder="What is the color of the shirt?",
            )
            btn = gr.Button("Submit")
        # Right column: model output.
        with gr.Column():
            answer = gr.Textbox(label="Answer", lines=3)

    # Run inference on the Submit button and also when Enter is pressed in
    # the question box; both events share the same handler and wiring.
    btn.click(
        fn=answer_image_question,
        inputs=[img, question],
        outputs=answer,
    )
    question.submit(
        fn=answer_image_question,
        inputs=[img, question],
        outputs=answer,
    )

demo.launch()