Spaces:
Paused
Paused
| # import gradio as gr | |
| # from transformers import BlipProcessor, BlipForConditionalGeneration | |
| # from PIL import Image | |
| # import torch | |
| # import requests | |
| # # Load model & processor | |
| # processor = BlipProcessor.from_pretrained( | |
| # "Salesforce/blip-image-captioning-base" | |
| # ) | |
| # model = BlipForConditionalGeneration.from_pretrained( | |
| # "Salesforce/blip-image-captioning-base" | |
| # ) | |
| # device = "cuda" if torch.cuda.is_available() else "cpu" | |
| # model.to(device) | |
| # def caption_image(image, prompt="", openai_api_key=""): | |
| # if not prompt or not prompt.strip(): | |
| # return "Please enter a prompt/question for the image." | |
| # image = image.convert("RGB") | |
| # # Use OpenAI API if key provided (unchanged) | |
| # if openai_api_key: | |
| # try: | |
| # import base64 | |
| # from io import BytesIO | |
| # buffered = BytesIO() | |
| # image.save(buffered, format="PNG") | |
| # img_b64 = base64.b64encode(buffered.getvalue()).decode() | |
| # headers = { | |
| # "Authorization": f"Bearer {openai_api_key}", | |
| # "Content-Type": "application/json" | |
| # } | |
| # data = { | |
| # "model": "gpt-4-vision-preview", | |
| # "messages": [ | |
| # { | |
| # "role": "user", | |
| # "content": [ | |
| # {"type": "text", "text": prompt.strip()}, | |
| # {"type": "image_url", "image_url": f"data:image/png;base64,{img_b64}"} | |
| # ] | |
| # } | |
| # ], | |
| # "max_tokens": 100 | |
| # } | |
| # resp = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=data) | |
| # if resp.status_code == 200: | |
| # result = resp.json() | |
| # return result["choices"][0]["message"]["content"].strip() | |
| # else: | |
| # return f"OpenAI API error: {resp.status_code} {resp.text}" | |
| # except Exception as e: | |
| # return f"OpenAI API error: {e}" | |
| # # BLIP: always use prompt as instruction, no retry, fast settings | |
| # p = prompt.strip() | |
| # prompt_text = f"Question: {p} Answer:" | |
| # inputs = processor(images=image, text=prompt_text, return_tensors="pt").to(device) | |
| # # Speed up: reduce beams and max_length | |
| # gen_kwargs = {"max_length": 25, "num_beams": 1, "early_stopping": True} | |
| # output = model.generate(**inputs, **gen_kwargs) | |
| # caption = processor.decode(output[0], skip_special_tokens=True) | |
| # # Extract answer after 'Answer:' if present | |
| # idx = caption.lower().find("answer:") | |
| # if idx != -1: | |
| # ans = caption[idx + len("answer:"):].strip() | |
| # if ans: | |
| # return ans | |
| # # Otherwise, return the raw caption | |
| # return caption.strip() | |
| # # Gradio UI: horizontal layout with image, prompt, button left; output right | |
| # with gr.Blocks() as demo: | |
| # gr.Markdown("## 🖼️ Image Captioning (Prompt-driven)\nUpload an image, enter a prompt, and click Submit. Output depends on both image and prompt.") | |
| # with gr.Row(): | |
| # with gr.Column(scale=2): | |
| # img = gr.Image(type="pil", label="Upload Image") | |
| # prompt = gr.Textbox(label="Prompt (ask a question)", placeholder="What is the color of the t-shirt?") | |
| # openai_api_key = gr.Textbox(label="OpenAI API Key (optional)", type="password", placeholder="sk-...", lines=1) | |
| # btn = gr.Button("Submit") | |
| # with gr.Column(scale=1): | |
| # out = gr.Textbox(label="Answer", lines=6) | |
| # btn.click(fn=caption_image, inputs=[img, prompt, openai_api_key], outputs=out) | |
| # demo.launch() | |
| import gradio as gr | |
| import torch | |
| from transformers import BlipProcessor, BlipForQuestionAnswering | |
| from PIL import Image | |
# ---------------------------
# BLIP VQA model setup (executed once, at import time)
# ---------------------------
MODEL_NAME = "Salesforce/blip-vqa-base"
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = BlipProcessor.from_pretrained(MODEL_NAME)
# nn.Module.to() and .eval() both return the module itself, so the move to
# the target device and the switch to evaluation mode can be chained.
model = BlipForQuestionAnswering.from_pretrained(MODEL_NAME).to(device).eval()
# ---------------------------
# Inference function
# ---------------------------
def answer_image_question(image, question):
    """Answer a free-form question about an image with the BLIP VQA model.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing has been uploaded.
        question: Question text from the textbox. ``None``, empty and
            whitespace-only values are all treated as "no question".

    Returns:
        The decoded model answer (surrounding whitespace removed), or a
        user-facing message when the image or the question is missing.
    """
    if image is None:
        return "Please upload an image."
    # Normalise None to "" so a missing value yields the friendly message
    # instead of an AttributeError on .strip().
    question = (question or "").strip()
    if not question:
        return "Please enter a question."
    # Normalise to 3-channel RGB (uploads may be e.g. RGBA or grayscale).
    image = image.convert("RGB")
    inputs = processor(
        images=image,
        text=question,
        return_tensors="pt",
    ).to(device)
    # Pure inference: no gradients are needed.
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_length=10,  # short cap keeps latency low
            num_beams=1,  # greedy decoding for speed
        )
    return processor.decode(output[0], skip_special_tokens=True).strip()
# ---------------------------
# Gradio UI
# ---------------------------
with gr.Blocks() as demo:
    gr.Markdown("## 🖼️ Image Question Answering (Fast & Accurate)")
    # A complete instruction with a concrete example question. A single "\n"
    # in Markdown is only a soft break, so the text is kept on one line.
    gr.Markdown(
        "Upload an image and ask any question about it, for example: "
        "*What is the color of the shirt?*"
    )
    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            img = gr.Image(type="pil", label="Upload Image")
            question = gr.Textbox(
                label="Question",
                placeholder="What is the color of the shirt?",
            )
            btn = gr.Button("Submit")
        # Right column: model output.
        with gr.Column():
            answer = gr.Textbox(label="Answer", lines=3)
    # The same handler runs from the Submit button and from pressing Enter
    # in the question box.
    btn.click(
        fn=answer_image_question,
        inputs=[img, question],
        outputs=answer,
    )
    question.submit(
        fn=answer_image_question,
        inputs=[img, question],
        outputs=answer,
    )

demo.launch()