import gradio as gr
import torch
from PIL import Image
from transformers import (
    BlipProcessor, BlipForConditionalGeneration,
    BlipForQuestionAnswering,
    CLIPProcessor, CLIPModel
)
import numpy as np

# ==================== Model Loading ====================
print("🚀 Loading models...")

# BLIP Image Captioning Model
caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# BLIP Visual Question Answering Model
vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

# CLIP Image Classification Model
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")

print("✅ Models loaded successfully!")

# ==================== Function Definitions ====================

def generate_caption(image):
    """Generate image caption"""
    if image is None:
        return "❌ Please upload an image first"
    try:
        # Process image
        inputs = caption_processor(image, return_tensors="pt")
        # Generate caption
        out = caption_model.generate(**inputs, max_length=50)
        caption = caption_processor.decode(out[0], skip_special_tokens=True)
        return f"📝 Image Caption:\n{caption}"
    except Exception as e:
        return f"❌ Processing failed: {str(e)}"

def answer_question(image, question):
    """Visual Question Answering"""
    if image is None:
        return "❌ Please upload an image first"
    if not question.strip():
        return "❌ Please enter a question"
    try:
        # Process inputs
        inputs = vqa_processor(image, question, return_tensors="pt")
        # Generate answer
        out = vqa_model.generate(**inputs, max_length=20)
        answer = vqa_processor.decode(out[0], skip_special_tokens=True)
        return f"❓ Question: {question}\n\n✅ Answer: {answer}"
    except Exception as e:
        return f"❌ Processing failed: {str(e)}"

def classify_image(image, categories):
    """Zero-shot Image Classification"""
    if image is None:
        return "❌ Please upload an image first"
    if not categories.strip():
        return "❌ Please enter categories"
    try:
        # Parse categories (skip empty entries from trailing commas)
        category_list = [cat.strip() for cat in categories.split(",") if cat.strip()]
        # Process image and text
        inputs = clip_processor(
            text=category_list,
            images=image,
            return_tensors="pt",
            padding=True
        )
        # Calculate similarity
        outputs = clip_model(**inputs)
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)[0]
        # Format results
        results = "🎯 Classification Results:\n\n"
        for category, prob in zip(category_list, probs):
            percentage = prob.item() * 100
            bar = "█" * int(percentage / 5)
            results += f"{category}: {percentage:.2f}% {bar}\n"
        return results
    except Exception as e:
        return f"❌ Processing failed: {str(e)}"

def multimodal_chat(image, message, history):
    """Multimodal Chat (Simplified)"""
    if image is None:
        return history + [[message, "❌ Please upload an image first to start chatting"]]
    try:
        # Use VQA model to process question
        inputs = vqa_processor(image, message, return_tensors="pt")
        out = vqa_model.generate(**inputs, max_length=30)
        response = vqa_processor.decode(out[0], skip_special_tokens=True)
        history.append([message, response])
        return history
    except Exception as e:
        history.append([message, f"❌ Processing failed: {str(e)}"])
        return history
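
# Note: blip-vqa-base is a single-turn model; `history` is only used for
# display in the Chatbot component and is never fed back into the model, so
# each message is answered independently of the conversation so far.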

# ==================== Gradio Interface ====================

# Custom CSS
custom_css = """
#title {
    text-align: center;
    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    font-size: 3em;
    font-weight: bold;
    margin-bottom: 10px;
}
#subtitle {
    text-align: center;
    color: #666;
    font-size: 1.2em;
    margin-bottom: 30px;
}
.feature-box {
    border: 2px solid #667eea;
    border-radius: 10px;
    padding: 20px;
    margin: 10px 0;
}
"""

with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Title
    gr.HTML('<h1 id="title">🤖 Vision Language AI Demo</h1>')
    gr.HTML('<p id="subtitle">Interactive application showcasing multiple vision-language AI capabilities</p>')

    # Tabbed Interface
    with gr.Tabs():
        # Tab 1: Image Captioning
        with gr.Tab("🖼️ Image Captioning"):
            gr.Markdown("### Upload an image and AI will generate a description")
            with gr.Row():
                with gr.Column():
                    caption_image = gr.Image(type="pil", label="Upload Image")
                    caption_btn = gr.Button("🎨 Generate Caption", variant="primary")
                with gr.Column():
                    caption_output = gr.Textbox(
                        label="Generated Caption",
                        lines=5,
                        placeholder="Caption will appear here..."
                    )
            # Examples
            gr.Examples(
                examples=[
                    ["https://images.unsplash.com/photo-1514888286974-6c03e2ca1dba"],
                    ["https://images.unsplash.com/photo-1506748686214-e9df14d4d9d0"],
                ],
                inputs=caption_image,
                label="📸 Click on examples to try"
            )
            caption_btn.click(
                fn=generate_caption,
                inputs=caption_image,
                outputs=caption_output
            )
            caption_image.change(
                fn=generate_caption,
                inputs=caption_image,
                outputs=caption_output
            )

        # Tab 2: Visual Question Answering
        with gr.Tab("🔍 Visual Question Answering"):
            gr.Markdown("### Upload an image and ask questions; AI will answer based on the image content")
            with gr.Row():
                with gr.Column():
                    vqa_image = gr.Image(type="pil", label="Upload Image")
                    vqa_question = gr.Textbox(
                        label="Enter Question",
                        placeholder="e.g., What color is the car? How many people are there?",
                        lines=2
                    )
                    vqa_btn = gr.Button("🤔 Get Answer", variant="primary")
                with gr.Column():
                    vqa_output = gr.Textbox(
                        label="AI Answer",
                        lines=6,
                        placeholder="Answer will appear here..."
                    )
            # Common question examples
            gr.Markdown("**💡 Common Question Examples:**")
            gr.Markdown("- What is in the image?\n- What color is...?\n- How many ... are there?\n- Is there a ... in the image?")
            vqa_btn.click(
                fn=answer_question,
                inputs=[vqa_image, vqa_question],
                outputs=vqa_output
            )

        # Tab 3: Zero-Shot Image Classification
        with gr.Tab("🏷️ Zero-Shot Classification"):
            gr.Markdown("### Define custom categories and AI will classify the image")
            with gr.Row():
                with gr.Column():
                    classify_image_input = gr.Image(type="pil", label="Upload Image")
                    classify_categories = gr.Textbox(
                        label="Categories (comma-separated)",
                        placeholder="e.g., cat, dog, bird, car, building",
                        value="cat, dog, bird, car, building",
                        lines=2
                    )
                    classify_btn = gr.Button("🎯 Classify", variant="primary")
                with gr.Column():
                    classify_output = gr.Textbox(
                        label="Classification Results",
                        lines=8,
                        placeholder="Results will appear here..."
                    )
            gr.Markdown("**💡 Tip:** You can input any categories; the model will calculate the similarity between the image and each category")
            classify_btn.click(
                fn=classify_image,
                inputs=[classify_image_input, classify_categories],
                outputs=classify_output
            )

        # Tab 4: Multimodal Chat
        with gr.Tab("💬 Multimodal Chat"):
            gr.Markdown("### Upload an image and have a conversation with AI about it")
            with gr.Row():
                with gr.Column(scale=1):
                    chat_image = gr.Image(type="pil", label="Upload Image")
                    gr.Markdown("**💡 Conversation Prompts:**")
                    gr.Markdown("- Describe this image\n- What's in the image?\n- Where is this?\n- What is the main color?")
                with gr.Column(scale=2):
                    chatbot = gr.Chatbot(label="Chat History", height=400)
                    chat_input = gr.Textbox(
                        label="Enter Message",
                        placeholder="Type your question...",
                        lines=2
                    )
                    with gr.Row():
                        chat_btn = gr.Button("📤 Send", variant="primary")
                        clear_btn = gr.Button("🗑️ Clear Chat")
            chat_btn.click(
                fn=multimodal_chat,
                inputs=[chat_image, chat_input, chatbot],
                outputs=chatbot
            )
            chat_input.submit(
                fn=multimodal_chat,
                inputs=[chat_image, chat_input, chatbot],
                outputs=chatbot
            )
            clear_btn.click(lambda: [], outputs=chatbot)
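
            # Optional UX tweak, a sketch only (an assumption, not in the
            # original): chain a second event on the same triggers to empty
            # the message box after sending:
            # chat_btn.click(lambda: "", outputs=chat_input)
            # chat_input.submit(lambda: "", outputs=chat_input)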

    # Footer
    gr.Markdown("---")
    gr.Markdown("""
    ### 📋 About This Application
    - **Models**: BLIP (Captioning & VQA) + CLIP (Classification)
    - **Framework**: Gradio + Transformers
    - **Deployment**: Can be deployed to Hugging Face Spaces
    - **Open Source**: All models are open source

    ⚡ **Performance Tip**: Use Hugging Face Spaces Zero GPU for significantly faster processing
    """)

# Launch application
if __name__ == "__main__":
    demo.launch(share=True)