File size: 18,107 Bytes
6ec3614
 
 
fdc7fa3
 
 
6ec3614
fdc7fa3
 
 
6ec3614
fdc7fa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ec3614
fdc7fa3
 
 
 
6ec3614
fdc7fa3
 
2e6a234
fdc7fa3
 
 
 
 
6ec3614
fdc7fa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ec3614
 
 
 
fdc7fa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ec3614
fdc7fa3
 
 
 
2e6a234
fdc7fa3
2cbccec
fdc7fa3
 
 
 
 
 
 
 
2cbccec
2e6a234
fdc7fa3
 
 
 
2cbccec
2e6a234
 
 
fdc7fa3
 
90a72f2
2e6a234
 
 
fdc7fa3
 
2e6a234
 
fdc7fa3
 
 
 
2e6a234
 
fdc7fa3
2e6a234
fdc7fa3
 
2e6a234
fdc7fa3
6ec3614
 
 
 
 
fdc7fa3
 
 
 
 
 
2e6a234
fdc7fa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ec3614
 
 
 
2e6a234
6ec3614
 
fdc7fa3
 
 
6ec3614
 
 
 
fdc7fa3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ec3614
 
fdc7fa3
6ec3614
fdc7fa3
 
6ec3614
fdc7fa3
 
 
 
 
 
 
 
 
 
6ec3614
 
 
 
 
 
fdc7fa3
 
 
 
 
 
 
 
6ec3614
 
fdc7fa3
 
 
6ec3614
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms.functional as TVF
from PIL import Image
from transformers import (
    AutoModel,
    AutoProcessor,
    AutoTokenizer,
    BlipForConditionalGeneration,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)

# Backward-compatible alias: the original import named a class that does not
# exist in transformers (`AutoModelForConditionalGeneration`); BLIP captioning
# checkpoints load via `BlipForConditionalGeneration`.
AutoModelForConditionalGeneration = BlipForConditionalGeneration

# Define constants
TITLE = "<h1><center>Enhanced Image Captioning Studio</center></h1>"  # HTML banner shown at the top of the Gradio page
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # prefer GPU when one is available

# Pre-defined caption types with templates.
# Each entry maps a caption type to exactly three prompt templates, selected by
# the length setting in generate_advanced_description():
#   [0] no length constraint ("any")
#   [1] numeric word count   -> uses the {word_count} placeholder
#   [2] named length          -> uses the {length} placeholder (e.g. "short")
CAPTION_TYPE_MAP = {
    "Descriptive": [
        "Write a descriptive caption for this image in a formal tone.",
        "Write a descriptive caption for this image in a formal tone within {word_count} words.",
        "Write a {length} descriptive caption for this image in a formal tone.",
    ],
    "Descriptive (Informal)": [
        "Write a descriptive caption for this image in a casual tone.",
        "Write a descriptive caption for this image in a casual tone within {word_count} words.",
        "Write a {length} descriptive caption for this image in a casual tone.",
    ],
    "AI Generation Prompt": [
        "Write a detailed prompt for AI image generation based on this image.",
        "Write a detailed prompt for AI image generation based on this image within {word_count} words.",
        "Write a {length} prompt for AI image generation based on this image.",
    ],
    "MidJourney": [
        "Write a MidJourney prompt for this image.",
        "Write a MidJourney prompt for this image within {word_count} words.",
        "Write a {length} MidJourney prompt for this image.",
    ],
    "Stable Diffusion": [
        "Write a Stable Diffusion prompt for this image.",
        "Write a Stable Diffusion prompt for this image within {word_count} words.",
        "Write a {length} Stable Diffusion prompt for this image.",
    ],
    "Art Critic": [
        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc.",
        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it within {word_count} words.",
        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it {length}.",
    ],
    "Product Listing": [
        "Write a caption for this image as though it were a product listing.",
        "Write a caption for this image as though it were a product listing. Keep it under {word_count} words.",
        "Write a {length} caption for this image as though it were a product listing.",
    ],
    "Social Media Post": [
        "Write a caption for this image as if it were being used for a social media post.",
        "Write a caption for this image as if it were being used for a social media post. Limit the caption to {word_count} words.",
        "Write a {length} caption for this image as if it were being used for a social media post.",
    ],
    "Tag List": [
        "Write a list of tags for this image.",
        "Write a list of tags for this image within {word_count} words.",
        "Write a {length} list of tags for this image.",
    ],
    "Technical Analysis": [
        "Provide a technical analysis of this image including camera details, lighting, composition, and quality.",
        "Provide a technical analysis of this image including camera details, lighting, composition, and quality within {word_count} words.",
        "Provide a {length} technical analysis of this image including camera details, lighting, composition, and quality.",
    ],
}

class ImageAdapter(nn.Module):
    """Project vision-encoder hidden states into a language model's embedding
    space, framing the image tokens with learned boundary embeddings.

    Args:
        input_features: width of a single vision hidden state.
        output_features: target embedding width for the language model.
        ln1: apply LayerNorm to the vision features before projection.
        pos_emb: add a learned positional embedding over image tokens.
        num_image_tokens: token count the positional embedding covers.
        deep_extract: concatenate five hidden-state layers feature-wise
            instead of using a single layer.
    """

    def __init__(self, input_features: int, output_features: int, ln1: bool, pos_emb: bool, num_image_tokens: int, deep_extract: bool):
        super().__init__()
        self.deep_extract = deep_extract

        # Deep extraction concatenates five layers along the feature axis,
        # so the projection input becomes five times wider.
        if self.deep_extract:
            input_features = input_features * 5

        # NOTE: submodules are created in this exact order so parameter
        # initialization draws from the RNG identically across versions.
        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)
        self.ln1 = nn.LayerNorm(input_features) if ln1 else nn.Identity()
        self.pos_emb = nn.Parameter(torch.zeros(num_image_tokens, input_features)) if pos_emb else None

        # Three learned special tokens: <|image_start|>, <|image_end|>, <|eot_id|>
        self.other_tokens = nn.Embedding(3, output_features)
        self.other_tokens.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self, vision_outputs: torch.Tensor):
        """Map stacked vision hidden states to (batch, tokens+2, output_features).

        `vision_outputs` is indexed by layer; the second-to-last layer is the
        primary feature source. Output is <|image_start|>, image tokens,
        <|image_end|>.
        """
        if not self.deep_extract:
            features = vision_outputs[-2]
        else:
            # Fuse five layers (fixed indices) along the feature dimension.
            layers = (
                vision_outputs[-2],
                vision_outputs[3],
                vision_outputs[7],
                vision_outputs[13],
                vision_outputs[20],
            )
            features = torch.cat(layers, dim=-1)
            assert len(features.shape) == 3, f"Expected 3, got {len(features.shape)}"
            assert features.shape[-1] == vision_outputs[-2].shape[-1] * 5, f"Expected {vision_outputs[-2].shape[-1] * 5}, got {features.shape[-1]}"

        features = self.ln1(features)

        if self.pos_emb is not None:
            assert features.shape[-2:] == self.pos_emb.shape, f"Expected {self.pos_emb.shape}, got {features.shape[-2:]}"
            features = features + self.pos_emb

        projected = self.linear2(self.activation(self.linear1(features)))

        # Look up the <|image_start|>/<|image_end|> embeddings per batch item.
        batch = projected.shape[0]
        marker_ids = torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(batch, -1)
        markers = self.other_tokens(marker_ids)
        assert markers.shape == (projected.shape[0], 2, projected.shape[2]), f"Expected {(projected.shape[0], 2, projected.shape[2])}, got {markers.shape}"

        return torch.cat((markers[:, 0:1], projected, markers[:, 1:2]), dim=1)

    def get_eot_embedding(self):
        """Return the learned <|eot_id|> embedding, shape (output_features,)."""
        eot_index = torch.tensor([2], device=self.other_tokens.weight.device)
        return self.other_tokens(eot_index).squeeze(0)

# Model loading functions
def load_siglip_model():
    """Load the SigLIP vision tower, frozen and in eval mode on DEVICE.

    Returns:
        (vision_model, processor) tuple.
    """
    print("Loading SigLIP model...")
    checkpoint = "google/siglip-so400m-patch14-384"
    processor = AutoProcessor.from_pretrained(checkpoint)
    # Only the vision tower is needed; discard the text side.
    vision_tower = AutoModel.from_pretrained(checkpoint).vision_model
    vision_tower.eval()
    vision_tower.requires_grad_(False)
    vision_tower.to(DEVICE)
    return vision_tower, processor

def load_blip_model():
    """Load the BLIP image-captioning model and its processor onto DEVICE.

    Returns:
        (model, processor) tuple, with the model in eval mode.
    """
    # FIX: transformers has no `AutoModelForConditionalGeneration`; the
    # correct class for this checkpoint is `BlipForConditionalGeneration`.
    # Imported locally so this function is self-contained even if the
    # top-of-file import list is stale.
    from transformers import BlipForConditionalGeneration

    print("Loading BLIP model...")
    model_id = "Salesforce/blip-image-captioning-large"
    processor = AutoProcessor.from_pretrained(model_id)
    model = BlipForConditionalGeneration.from_pretrained(model_id)
    model.to(DEVICE)
    model.eval()
    return model, processor

# Initialize models (with optional lazy loading)
class ModelManager:
    """Lazy holder for all captioning models; nothing is loaded until needed."""

    def __init__(self):
        # Every slot starts empty; load_models() populates them on demand.
        self.blip_model = None
        self.blip_processor = None
        self.siglip_model = None
        self.siglip_processor = None
        self.image_adapter = None
        self.llm_model = None
        self.tokenizer = None
        self.models_loaded = False

    def load_models(self):
        """Idempotently load the models; returns self so calls can chain."""
        if self.models_loaded:
            return self

        # BLIP currently serves both the basic and advanced paths; the other
        # slots are placeholders for a fuller custom pipeline.
        self.blip_model, self.blip_processor = load_blip_model()

        self.models_loaded = True
        return self

# Shared singleton; the actual models load lazily on the first caption request.
model_manager = ModelManager()

def generate_basic_caption(image, prompt="a detailed caption of this image:"):
    """Run BLIP conditional captioning on *image*, seeded with *prompt*."""
    # Ensure models are present (no-op after the first call).
    model_manager.load_models()

    processor = model_manager.blip_processor
    encoded = processor(image, prompt, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        generated = model_manager.blip_model.generate(**encoded, max_new_tokens=100)

    return processor.decode(generated[0], skip_special_tokens=True)

def _build_prompt(caption_type, caption_length, extra_options, custom_prompt):
    """Assemble the instruction prompt from template, length, and options."""
    length = None if caption_length == "any" else caption_length
    if isinstance(length, str):
        try:
            length = int(length)
        except ValueError:
            pass  # named lengths ("short", "long", ...) stay as strings

    # Template index: 0 = no length constraint, 1 = numeric word count,
    # 2 = named length descriptor (see CAPTION_TYPE_MAP).
    if length is None:
        map_idx = 0
    elif isinstance(length, int):
        map_idx = 1
    else:
        map_idx = 2

    prompt_str = CAPTION_TYPE_MAP.get(caption_type, CAPTION_TYPE_MAP["Descriptive"])[map_idx]

    if extra_options:
        prompt_str += " " + " ".join(extra_options)

    # Fill either placeholder with the raw length selection.
    prompt_str = prompt_str.format(length=caption_length, word_count=caption_length)

    # A non-empty custom prompt overrides everything assembled above.
    if custom_prompt and custom_prompt.strip():
        prompt_str = custom_prompt.strip()
    return prompt_str


def _collect_aspects(image, detail_level, emotion_focus, style_focus):
    """Generate the enabled per-aspect captions, in display order.

    Each table row: (section title, enabled?, targeted prompt).
    """
    aspect_table = [
        ("Main Subject(s)", detail_level >= 2,
         "Describe the main subjects in this image with details about their appearance:"),
        ("Setting/Background", detail_level >= 3,
         "Describe the setting, location, and background of this image:"),
        ("Visual Style/Colors", style_focus >= 3,
         "Describe the color scheme, visual composition, and artistic style of this image:"),
        ("Mood/Emotional Tone", emotion_focus >= 3,
         "Describe the mood, emotional tone, and atmosphere conveyed in this image:"),
        ("Lighting/Atmosphere", detail_level >= 4 or style_focus >= 4,
         "Describe the lighting conditions and time of day in this image:"),
        ("Fine Details/Textures", detail_level >= 5,
         "Describe the fine details, textures, and small elements visible in this image:"),
    ]
    return [(title, generate_basic_caption(image, prompt))
            for title, enabled, prompt in aspect_table if enabled]


def generate_advanced_description(image, caption_type, caption_length, detail_level, emotion_focus, style_focus, extra_options, custom_prompt):
    """Generate an advanced description using multiple targeted prompts.

    Args:
        image: PIL image (or None, which returns an instruction message).
        caption_type: key into CAPTION_TYPE_MAP (falls back to "Descriptive").
        caption_length: "any", a named length, or a numeric string.
        detail_level / emotion_focus / style_focus: 1-5 sliders gating which
            aspect captions are generated.
        extra_options: list of extra instruction sentences to append.
        custom_prompt: optional override for the assembled prompt.

    Returns:
        Markdown-formatted description, or an error message string on failure.
    """
    if image is None:
        return "Please upload an image to generate a description."

    try:
        # Load models if not already loaded.
        model_manager.load_models()

        prompt_str = _build_prompt(caption_type, caption_length, extra_options, custom_prompt)
        print(f"Using prompt: {prompt_str}")

        # Generate the base caption plus the aspect captions enabled by the
        # slider settings.
        with torch.no_grad():
            basic_caption = generate_basic_caption(image, prompt_str)
            descriptions = [("Basic Caption", basic_caption)]
            descriptions += _collect_aspects(image, detail_level, emotion_focus, style_focus)

        formatted_result = f"## Basic Caption:\n{basic_caption}\n\n"

        # Detailed section lists every aspect caption (skipping the basic one).
        if detail_level >= 2:
            formatted_result += "## Detailed Description:\n\n"
            for title, desc in descriptions[1:]:
                formatted_result += f"**{title}:** {desc}\n\n"

        # For AI-generation caption types, add a condensed one-line prompt.
        if caption_type in ["AI Generation Prompt", "MidJourney", "Stable Diffusion"]:
            ai_descriptions = [basic_caption.strip(".")]
            for _, desc in descriptions[1:]:
                if len(desc) > 10:
                    # First sentence only, to keep the prompt compact.
                    ai_descriptions.append(desc.split(".")[0])

            formatted_result += "## Suggested AI Image Generation Prompt:\n\n"
            ai_prompt = ", ".join(ai_descriptions)

            # Append quality qualifiers driven by the slider settings.
            qualifiers = []
            if detail_level >= 4:
                qualifiers.append("highly detailed")
                qualifiers.append("intricate")
            if emotion_focus >= 4:
                qualifiers.append("emotional")
                qualifiers.append("evocative")
            if style_focus >= 4:
                qualifiers.append("artistic composition")
                qualifiers.append("professional photography")

            if qualifiers:
                ai_prompt += ", " + ", ".join(qualifiers)

            formatted_result += ai_prompt

        return formatted_result

    except Exception as e:
        return f"Error generating description: {str(e)}"

# Create Gradio interface
with gr.Blocks(title="Enhanced Image Captioning Studio") as demo:
    gr.HTML(TITLE)
    gr.Markdown("Upload an image to generate detailed captions and descriptions tailored to your needs.")
    
    with gr.Row():
        # Left column: all inputs and controls.
        with gr.Column(scale=1):
            input_image = gr.Image(label="Upload Image", type="pil")
            
            # Caption type selects which template family from CAPTION_TYPE_MAP.
            caption_type = gr.Dropdown(
                choices=list(CAPTION_TYPE_MAP.keys()),
                label="Caption Type",
                value="Descriptive",
            )
            
            # Named lengths plus numeric word counts 20..300 in steps of 20.
            caption_length = gr.Dropdown(
                choices=["any", "very short", "short", "medium-length", "long", "very long"] +
                        [str(i) for i in range(20, 301, 20)],
                label="Caption Length",
                value="medium-length",
            )
            
            with gr.Accordion("Advanced Settings", open=False):
                # 1-5 sliders gate which aspect captions are generated.
                with gr.Row():
                    detail_slider = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Detail Level")
                    emotion_slider = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Emotion Focus")
                    style_slider = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Style/Artistic Focus")
            
                # Extra instruction sentences appended verbatim to the prompt.
                extra_options = gr.CheckboxGroup(
                    choices=[
                        "Include information about lighting.",
                        "Include information about camera angle.",
                        "Include information about whether there is a watermark or not.",
                        "Include information about any artifacts or quality issues.",
                        "If it is a photo, include likely camera details such as aperture, shutter speed, ISO, etc.",
                        "Do NOT include anything sexual; keep it PG.",
                        "Do NOT mention the image's resolution.",
                        "Include information about the subjective aesthetic quality of the image.",
                        "Include information on the image's composition style.",
                        "Do NOT mention any text that is in the image.",
                        "Specify the depth of field and focus.",
                        "Mention the likely use of artificial or natural lighting sources.",
                        "ONLY describe the most important elements of the image."
                    ],
                    label="Additional Options"
                )
                
                custom_prompt = gr.Textbox(label="Custom Prompt (optional, will override other settings)")
                gr.Markdown("**Note:** Custom prompts may not work with all models and settings.")
            
            generate_btn = gr.Button("Generate Description", variant="primary")
        
        # Right column: the generated output.
        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Generated Description", lines=25)
    
    # Set up event handlers
    generate_btn.click(
        fn=generate_advanced_description, 
        inputs=[
            input_image, 
            caption_type, 
            caption_length, 
            detail_slider, 
            emotion_slider, 
            style_slider, 
            extra_options, 
            custom_prompt
        ], 
        outputs=output_text
    )
    
    gr.Markdown("""
    ## How to Use
    1. Upload an image
    2. Select the type of caption you want
    3. Choose a length preference
    4. Adjust advanced settings if needed:
       - Detail Level: Controls the comprehensiveness of the description
       - Emotion Focus: Emphasizes mood and feelings in the output
       - Style Focus: Emphasizes artistic elements in the output
    5. Select any additional options you'd like included
    6. Click "Generate Description"
    
    ## About
    This application combines multiple image analysis techniques to generate rich, 
    detailed descriptions of images. It's especially useful for creating prompts 
    for AI image generators like Stable Diffusion, Midjourney, or DALL-E.
    """)

# Launch the app
if __name__ == "__main__":
    demo.launch()