import gradio as gr
import torch
from PIL import Image
from transformers import (
    BlipProcessor, BlipForConditionalGeneration,
    BlipForQuestionAnswering,
    CLIPProcessor, CLIPModel
)
# ==================== Model Loading ====================
print("πŸ”„ Loading models...")
# BLIP Image Captioning Model
caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
caption_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
# BLIP Visual Question Answering Model
vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
vqa_model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
# CLIP Image Classification Model
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
print("βœ… Models loaded successfully!")
# ==================== Function Definitions ====================
def generate_caption(image):
    """Generate image caption"""
    if image is None:
        return "❌ Please upload an image first"
    try:
        # Process image
        inputs = caption_processor(image, return_tensors="pt")
        # Generate caption
        out = caption_model.generate(**inputs, max_length=50)
        caption = caption_processor.decode(out[0], skip_special_tokens=True)
        return f"📝 Image Caption:\n{caption}"
    except Exception as e:
        return f"❌ Processing failed: {str(e)}"
def answer_question(image, question):
    """Visual Question Answering"""
    if image is None:
        return "❌ Please upload an image first"
    if not question.strip():
        return "❌ Please enter a question"
    try:
        # Process inputs
        inputs = vqa_processor(image, question, return_tensors="pt")
        # Generate answer
        out = vqa_model.generate(**inputs, max_length=20)
        answer = vqa_processor.decode(out[0], skip_special_tokens=True)
        return f"❓ Question: {question}\n\n✅ Answer: {answer}"
    except Exception as e:
        return f"❌ Processing failed: {str(e)}"
def classify_image(image, categories):
    """Zero-shot Image Classification"""
    if image is None:
        return "❌ Please upload an image first"
    if not categories.strip():
        return "❌ Please enter categories"
    try:
        # Parse categories, skipping empty entries left by stray commas
        category_list = [cat.strip() for cat in categories.split(",") if cat.strip()]
        # Process image and text
        inputs = clip_processor(
            text=category_list,
            images=image,
            return_tensors="pt",
            padding=True
        )
        # Calculate similarity
        outputs = clip_model(**inputs)
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)[0]
        # Format results
        results = "🎯 Classification Results:\n\n"
        for category, prob in zip(category_list, probs):
            percentage = prob.item() * 100
            bar = "█" * int(percentage / 5)
            results += f"{category}: {percentage:.2f}% {bar}\n"
        return results
    except Exception as e:
        return f"❌ Processing failed: {str(e)}"
def multimodal_chat(image, message, history):
    """Multimodal Chat (Simplified)"""
    if image is None:
        return history + [[message, "❌ Please upload an image first to start chatting"]]
    try:
        # Use VQA model to process question
        inputs = vqa_processor(image, message, return_tensors="pt")
        out = vqa_model.generate(**inputs, max_length=30)
        response = vqa_processor.decode(out[0], skip_special_tokens=True)
        history.append([message, response])
        return history
    except Exception as e:
        history.append([message, f"❌ Processing failed: {str(e)}"])
        return history
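# Note: each turn is answered independently by the VQA model; the chat history
# is displayed in the UI but is not fed back into the model.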
# ==================== Gradio Interface ====================
# Custom CSS
custom_css = """
#title {
text-align: center;
background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-size: 3em;
font-weight: bold;
margin-bottom: 10px;
}
#subtitle {
text-align: center;
color: #666;
font-size: 1.2em;
margin-bottom: 30px;
}
.feature-box {
border: 2px solid #667eea;
border-radius: 10px;
padding: 20px;
margin: 10px 0;
}
"""
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Title
    gr.HTML('<h1 id="title">🤖 Vision Language AI Demo</h1>')
    gr.HTML('<p id="subtitle">Interactive application showcasing multiple vision-language AI capabilities</p>')
    # Tabbed Interface
    with gr.Tabs():
        # Tab 1: Image Captioning
        with gr.Tab("🖼️ Image Captioning"):
            gr.Markdown("### Upload an image and the AI will generate a description")
            with gr.Row():
                with gr.Column():
                    caption_image = gr.Image(type="pil", label="Upload Image")
                    caption_btn = gr.Button("🎨 Generate Caption", variant="primary")
                with gr.Column():
                    caption_output = gr.Textbox(
                        label="Generated Caption",
                        lines=5,
                        placeholder="Caption will appear here..."
                    )
            # Examples
            gr.Examples(
                examples=[
                    ["https://images.unsplash.com/photo-1514888286974-6c03e2ca1dba"],
                    ["https://images.unsplash.com/photo-1506748686214-e9df14d4d9d0"],
                ],
                inputs=caption_image,
                label="📸 Click on examples to try"
            )
            caption_btn.click(
                fn=generate_caption,
                inputs=caption_image,
                outputs=caption_output
            )
            caption_image.change(
                fn=generate_caption,
                inputs=caption_image,
                outputs=caption_output
            )
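            # Note: the .change() handler above captions the image as soon as
            # it is uploaded, so the button is a convenience for re-running.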
        # Tab 2: Visual Question Answering
        with gr.Tab("🔍 Visual Question Answering"):
            gr.Markdown("### Upload an image and ask a question; the AI will answer based on the image content")
            with gr.Row():
                with gr.Column():
                    vqa_image = gr.Image(type="pil", label="Upload Image")
                    vqa_question = gr.Textbox(
                        label="Enter Question",
                        placeholder="e.g., What color is the car? How many people are there?",
                        lines=2
                    )
                    vqa_btn = gr.Button("🤔 Get Answer", variant="primary")
                with gr.Column():
                    vqa_output = gr.Textbox(
                        label="AI Answer",
                        lines=6,
                        placeholder="Answer will appear here..."
                    )
            # Common question examples
            gr.Markdown("**💡 Common Question Examples:**")
            gr.Markdown("- What is in the image?\n- What color is...?\n- How many ... are there?\n- Is there a ... in the image?")
            vqa_btn.click(
                fn=answer_question,
                inputs=[vqa_image, vqa_question],
                outputs=vqa_output
            )
        # Tab 3: Image Classification
        with gr.Tab("🏷️ Zero-Shot Classification"):
            gr.Markdown("### Define custom categories and the AI will classify the image")
            with gr.Row():
                with gr.Column():
                    classify_image_input = gr.Image(type="pil", label="Upload Image")
                    classify_categories = gr.Textbox(
                        label="Categories (comma-separated)",
                        placeholder="e.g., cat, dog, bird, car, building",
                        value="cat, dog, bird, car, building",
                        lines=2
                    )
                    classify_btn = gr.Button("🎯 Classify", variant="primary")
                with gr.Column():
                    classify_output = gr.Textbox(
                        label="Classification Results",
                        lines=8,
                        placeholder="Results will appear here..."
                    )
            gr.Markdown("**💡 Tip:** You can enter any categories; the model will score the similarity between the image and each one")
            classify_btn.click(
                fn=classify_image,
                inputs=[classify_image_input, classify_categories],
                outputs=classify_output
            )
        # Tab 4: Multimodal Chat
        with gr.Tab("💬 Multimodal Chat"):
            gr.Markdown("### Upload an image and have a conversation with the AI about it")
            with gr.Row():
                with gr.Column(scale=1):
                    chat_image = gr.Image(type="pil", label="Upload Image")
                    gr.Markdown("**💡 Conversation Prompts:**")
                    gr.Markdown("- Describe this image\n- What's in the image?\n- Where is this?\n- What is the main color?")
                with gr.Column(scale=2):
                    chatbot = gr.Chatbot(label="Chat History", height=400)
                    chat_input = gr.Textbox(
                        label="Enter Message",
                        placeholder="Type your question...",
                        lines=2
                    )
                    with gr.Row():
                        chat_btn = gr.Button("📤 Send", variant="primary")
                        clear_btn = gr.Button("🗑️ Clear Chat")
            chat_btn.click(
                fn=multimodal_chat,
                inputs=[chat_image, chat_input, chatbot],
                outputs=chatbot
            )
            chat_input.submit(
                fn=multimodal_chat,
                inputs=[chat_image, chat_input, chatbot],
                outputs=chatbot
            )
            clear_btn.click(lambda: [], outputs=chatbot)
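            # Note: the message box keeps its text after sending; adding
            # chat_input to the outputs (and returning an empty string for it)
            # would clear it automatically.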
    # Footer
    gr.Markdown("---")
    gr.Markdown("""
    ### 📚 About This Application
    - **Models**: BLIP (Captioning & VQA) + CLIP (Classification)
    - **Framework**: Gradio + Transformers
    - **Deployment**: Can be deployed to Hugging Face Spaces
    - **Open Source**: All models are open source

    ⚡ **Performance Tip**: Use Hugging Face Spaces Zero GPU for significantly faster processing
    """)
# Launch application
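# share=True creates a temporary public link when run locally; Hugging Face
# Spaces ignores it and serves the app at the Space URL directly.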
if __name__ == "__main__":
    demo.launch(share=True)