# File size: 3,608 Bytes
# 76772c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import time

# Hugging Face checkpoint identifiers for the two stages of the pipeline.
BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
GPT2_CHECKPOINT = "gpt2"

# Stage 1: BLIP processor + model, used to produce the initial caption.
print("Loading BLIP model...")
blip_processor = BlipProcessor.from_pretrained(BLIP_CHECKPOINT)
blip_model = BlipForConditionalGeneration.from_pretrained(BLIP_CHECKPOINT)

# Stage 2: GPT-2 tokenizer + model, used to refine that caption.
print("Loading GPT-2 model...")
gpt2_tokenizer = GPT2Tokenizer.from_pretrained(GPT2_CHECKPOINT)
gpt2_model = GPT2LMHeadModel.from_pretrained(GPT2_CHECKPOINT)

# Prefer CUDA when torch reports it as available; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
for _model in (blip_model, gpt2_model):
    _model.to(device)
print(f"Using device: {device}")

def generate_caption(image):
    """Generate a caption with BLIP and refine it with GPT-2.

    Args:
        image: The image to caption, as delivered by the ``gr.Image(type="pil")``
            input wired to this function, or ``None`` when no image was given.

    Returns:
        A 3-tuple of strings ``(blip_caption, enhanced_caption, timing)``.
        ``timing`` reads ``"Processing time: N.NN seconds"``. When ``image``
        is ``None`` both captions are empty and the third element is a
        message asking for an image instead of a timing.
    """
    # Guard against a click with no upload: without this the processor call
    # below would be handed None and fail with an opaque error.
    if image is None:
        return "", "", "No image provided. Please upload an image first."

    start_time = time.time()

    # Stage 1: BLIP produces the base caption.
    with torch.no_grad():
        inputs = blip_processor(image, return_tensors="pt").to(device)
        generated_ids = blip_model.generate(**inputs, max_length=30)
    blip_caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)

    # Stage 2: GPT-2 continues a prompt built around the base caption.
    prompt = f"An image shows {blip_caption}. A detailed description would be: "

    # Calling the tokenizer (rather than .encode) also yields the attention
    # mask, which generate() would otherwise have to guess.
    encoded = gpt2_tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = encoded["input_ids"].shape[1]

    with torch.no_grad():
        output = gpt2_model.generate(
            **encoded,
            max_length=100,
            num_return_sequences=1,
            # Sampling must be switched on explicitly; otherwise decoding is
            # greedy and temperature/top_k/top_p have no effect.
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            no_repeat_ngram_size=2,
            # GPT-2 defines no pad token; reuse EOS so padding is well-defined.
            pad_token_id=gpt2_tokenizer.eos_token_id,
        )

    # Keep only the newly generated tokens. Splitting the decoded text on the
    # marker phrase would drop content whenever the model repeats that phrase.
    enhanced_caption = gpt2_tokenizer.decode(
        output[0][prompt_length:], skip_special_tokens=True
    ).strip()

    processing_time = time.time() - start_time

    return blip_caption, enhanced_caption, f"Processing time: {processing_time:.2f} seconds"

# Static page copy, kept apart from the layout code below.
_INTRO_MD = "This app generates captions for your images using BLIP for initial caption generation and GPT-2 for enhancement."
_HOW_IT_WORKS_MD = """
    1. **BLIP Model**: Extracts visual features and generates a basic caption
    2. **GPT-2 Model**: Enhances the basic caption to make it more detailed and descriptive
    3. **Both captions are displayed** so you can see the improvement from the enhancement process
    """

# Build the Gradio UI: an upload column on the left, result boxes on the right.
with gr.Blocks(title="AI Image Captioning App", theme=gr.themes.Soft()) as demo:
    gr.Markdown(value="# AI-Powered Image Captioning")
    gr.Markdown(value=_INTRO_MD)

    with gr.Row():
        # Input side: the image and the button that triggers captioning.
        with gr.Column(scale=1):
            image_input = gr.Image(label="Upload an image", type="pil")
            caption_button = gr.Button(value="Generate Captions", variant="primary")

        # Output side: one textbox per value returned by generate_caption.
        with gr.Column(scale=2), gr.Group():
            blip_caption_output = gr.Textbox(label="BLIP Caption (Base)")
            enhanced_caption_output = gr.Textbox(label="GPT-2 Enhanced Caption")
            processing_time = gr.Textbox(label="Performance")

    # The order of the outputs matches the order of generate_caption's return tuple.
    caption_button.click(
        fn=generate_caption,
        inputs=[image_input],
        outputs=[blip_caption_output, enhanced_caption_output, processing_time],
    )

    gr.Markdown(value="## How it works")
    gr.Markdown(value=_HOW_IT_WORKS_MD)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()