# Jaswanth0217 — "added app file" (commit 76772c3)
import gradio as gr
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import time
# Run both models on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stage 1: BLIP turns the uploaded image into a short base caption.
print("Loading BLIP model...")
_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
blip_processor = BlipProcessor.from_pretrained(_BLIP_CHECKPOINT)
# nn.Module.to() moves the parameters in place and returns the same module.
blip_model = BlipForConditionalGeneration.from_pretrained(_BLIP_CHECKPOINT).to(device)

# Stage 2: GPT-2 continues a text prompt built from the base caption.
print("Loading GPT-2 model...")
_GPT2_CHECKPOINT = "gpt2"
gpt2_tokenizer = GPT2Tokenizer.from_pretrained(_GPT2_CHECKPOINT)
gpt2_model = GPT2LMHeadModel.from_pretrained(_GPT2_CHECKPOINT).to(device)

print(f"Using device: {device}")
def _blip_base_caption(image):
    """Return BLIP's caption for *image*, decoded with special tokens removed."""
    inputs = blip_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = blip_model.generate(**inputs, max_length=30)
    return blip_processor.decode(generated_ids[0], skip_special_tokens=True)


def _gpt2_enhanced_caption(blip_caption):
    """Continue a descriptive prompt built from *blip_caption* with GPT-2.

    Returns only the newly generated text, not the prompt.
    """
    prompt = f"An image shows {blip_caption}. A detailed description would be: "
    # Calling the tokenizer (rather than .encode) also yields an attention_mask,
    # which generate() needs to avoid its missing-mask warning.
    encoded = gpt2_tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = encoded["input_ids"].shape[-1]
    with torch.no_grad():
        output = gpt2_model.generate(
            **encoded,
            max_length=100,  # counts prompt tokens too
            num_return_sequences=1,
            # temperature/top_k/top_p only take effect when sampling is enabled;
            # without do_sample=True, generate() silently falls back to greedy search.
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            no_repeat_ngram_size=2,
            # GPT-2 has no pad token; reuse EOS so generate() does not warn.
            pad_token_id=gpt2_tokenizer.eos_token_id,
        )
    # Slice off the prompt tokens instead of splitting the decoded text on a marker,
    # which would truncate the result if GPT-2 repeated the marker phrase.
    continuation_ids = output[0][prompt_length:]
    return gpt2_tokenizer.decode(continuation_ids, skip_special_tokens=True).strip()


def generate_caption(image):
    """Generate caption using BLIP and refine it with GPT-2.

    Args:
        image: PIL image from the Gradio input, or None when nothing is uploaded.

    Returns:
        A 3-tuple ``(blip_caption, enhanced_caption, timing_message)``. It matches
        the three output textboxes of the UI.
    """
    start_time = time.perf_counter()
    if image is None:
        # Gradio passes None when the button is clicked with no image uploaded.
        return "", "", "Please upload an image first."

    blip_caption = _blip_base_caption(image)
    enhanced_caption = _gpt2_enhanced_caption(blip_caption)

    processing_time = time.perf_counter() - start_time
    return blip_caption, enhanced_caption, f"Processing time: {processing_time:.2f} seconds"
# Static Markdown texts rendered in the interface.
_INTRO_MD = "This app generates captions for your images using BLIP for initial caption generation and GPT-2 for enhancement."
_HOW_IT_WORKS_MD = """
1. **BLIP Model**: Extracts visual features and generates a basic caption
2. **GPT-2 Model**: Enhances the basic caption to make it more detailed and descriptive
3. **Both captions are displayed** so you can see the improvement from the enhancement process
"""

# Build the Gradio interface.
with gr.Blocks(title="AI Image Captioning App", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# AI-Powered Image Captioning")
    gr.Markdown(_INTRO_MD)

    with gr.Row():
        # Left column: image upload and the trigger button.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload an image")
            caption_button = gr.Button("Generate Captions", variant="primary")
        # Right column: one textbox per value returned by generate_caption, in order.
        with gr.Column(scale=2):
            with gr.Group():
                blip_caption_output = gr.Textbox(label="BLIP Caption (Base)")
                enhanced_caption_output = gr.Textbox(label="GPT-2 Enhanced Caption")
                processing_time_output = gr.Textbox(label="Performance")

    caption_button.click(
        fn=generate_caption,
        inputs=[image_input],
        outputs=[blip_caption_output, enhanced_caption_output, processing_time_output],
    )

    gr.Markdown("## How it works")
    gr.Markdown(_HOW_IT_WORKS_MD)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()