"""Gradio app that captions an image with BLIP and elaborates the caption with GPT-2."""

import time
from typing import Optional, Tuple

import gradio as gr
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Load BLIP model for initial caption generation
print("Loading BLIP model...")
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Load GPT-2 model for caption refinement
print("Loading GPT-2 model...")
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")

# Move models to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
blip_model.to(device)
gpt2_model.to(device)
print(f"Using device: {device}")


def _caption_with_blip(image: Image.Image) -> str:
    """Return the caption BLIP generates for *image* (at most 30 tokens)."""
    inputs = blip_processor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        generated_ids = blip_model.generate(**inputs, max_length=30)
    return blip_processor.decode(generated_ids[0], skip_special_tokens=True)


def _enhance_with_gpt2(caption: str) -> str:
    """Return GPT-2's continuation of a prompt built around *caption*.

    Only the newly generated text is returned; the prompt itself is removed.
    """
    prompt = f"An image shows {caption}. A detailed description would be: "

    # Calling the tokenizer (rather than .encode) also yields the attention
    # mask, which is forwarded to generate() together with the input ids.
    encoded = gpt2_tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = encoded["input_ids"].shape[-1]

    with torch.no_grad():
        output = gpt2_model.generate(
            **encoded,
            max_length=100,
            num_return_sequences=1,
            # temperature / top_k / top_p only take effect when sampling is
            # enabled; without do_sample=True decoding is greedy.
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            no_repeat_ngram_size=2,
            # GPT-2 defines no pad token; use EOS explicitly for padding.
            pad_token_id=gpt2_tokenizer.eos_token_id,
        )

    # Drop the prompt tokens before decoding instead of splitting the decoded
    # text on the prompt's marker phrase.
    continuation = output[0][prompt_length:]
    return gpt2_tokenizer.decode(continuation, skip_special_tokens=True).strip()


def generate_caption(image: Optional[Image.Image]) -> Tuple[str, str, str]:
    """Generate caption using BLIP and refine it with GPT-2.

    Args:
        image: The image to caption, or ``None`` when nothing was uploaded.

    Returns:
        A tuple ``(blip_caption, enhanced_caption, timing_message)`` where
        ``timing_message`` reports the elapsed processing time in seconds.

    Raises:
        gr.Error: If *image* is ``None``.
    """
    if image is None:
        raise gr.Error("Please upload an image before generating captions.")

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # system clock adjustments.
    start_time = time.perf_counter()

    blip_caption = _caption_with_blip(image)
    enhanced_caption = _enhance_with_gpt2(blip_caption)

    elapsed = time.perf_counter() - start_time
    return blip_caption, enhanced_caption, f"Processing time: {elapsed:.2f} seconds"


# Create Gradio interface
with gr.Blocks(title="AI Image Captioning App", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# AI-Powered Image Captioning")
    gr.Markdown("This app generates captions for your images using BLIP for initial caption generation and GPT-2 for enhancement.")

    with gr.Row():
        # Left column: image upload and the trigger button.
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload an image")
            caption_button = gr.Button("Generate Captions", variant="primary")

        # Right column: the two captions and the timing message.
        with gr.Column(scale=2):
            with gr.Group():
                blip_caption_output = gr.Textbox(label="BLIP Caption (Base)")
                enhanced_caption_output = gr.Textbox(label="GPT-2 Enhanced Caption")
                processing_time = gr.Textbox(label="Performance")

    # The three return values of generate_caption map onto the outputs in order.
    caption_button.click(
        generate_caption,
        inputs=[image_input],
        outputs=[blip_caption_output, enhanced_caption_output, processing_time]
    )

    gr.Markdown("## How it works")
    gr.Markdown("""
    1. **BLIP Model**: Extracts visual features and generates a basic caption
    2. **GPT-2 Model**: Enhances the basic caption to make it more detailed and descriptive
    3. **Both captions are displayed** so you can see the improvement from the enhancement process
    """)

# Launch the app
if __name__ == "__main__":
    demo.launch()