"""Gradio app that builds multi-section image descriptions with a BLIP captioner."""

import traceback

import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms.functional as TVF
from PIL import Image
from transformers import (
    AutoModel,
    AutoProcessor,
    AutoTokenizer,
    # NOTE: transformers has no `AutoModelForConditionalGeneration`; importing it
    # raised ImportError at startup. BLIP's captioning head is this concrete class.
    BlipForConditionalGeneration,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
)

# Define constants
TITLE = "<h1>Enhanced Image Captioning Studio</h1>"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pre-defined caption types with templates.
# Each entry holds three templates, selected by `_select_prompt_template`:
#   [0] no length constraint ("any"),
#   [1] numeric length, uses {word_count},
#   [2] descriptive length such as "short", uses {length}.
CAPTION_TYPE_MAP = {
    "Descriptive": [
        "Write a descriptive caption for this image in a formal tone.",
        "Write a descriptive caption for this image in a formal tone within {word_count} words.",
        "Write a {length} descriptive caption for this image in a formal tone.",
    ],
    "Descriptive (Informal)": [
        "Write a descriptive caption for this image in a casual tone.",
        "Write a descriptive caption for this image in a casual tone within {word_count} words.",
        "Write a {length} descriptive caption for this image in a casual tone.",
    ],
    "AI Generation Prompt": [
        "Write a detailed prompt for AI image generation based on this image.",
        "Write a detailed prompt for AI image generation based on this image within {word_count} words.",
        "Write a {length} prompt for AI image generation based on this image.",
    ],
    "MidJourney": [
        "Write a MidJourney prompt for this image.",
        "Write a MidJourney prompt for this image within {word_count} words.",
        "Write a {length} MidJourney prompt for this image.",
    ],
    "Stable Diffusion": [
        "Write a Stable Diffusion prompt for this image.",
        "Write a Stable Diffusion prompt for this image within {word_count} words.",
        "Write a {length} Stable Diffusion prompt for this image.",
    ],
    "Art Critic": [
        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc.",
        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it within {word_count} words.",
        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it {length}.",
    ],
    "Product Listing": [
        "Write a caption for this image as though it were a product listing.",
        "Write a caption for this image as though it were a product listing. Keep it under {word_count} words.",
        "Write a {length} caption for this image as though it were a product listing.",
    ],
    "Social Media Post": [
        "Write a caption for this image as if it were being used for a social media post.",
        "Write a caption for this image as if it were being used for a social media post. Limit the caption to {word_count} words.",
        "Write a {length} caption for this image as if it were being used for a social media post.",
    ],
    "Tag List": [
        "Write a list of tags for this image.",
        "Write a list of tags for this image within {word_count} words.",
        "Write a {length} list of tags for this image.",
    ],
    "Technical Analysis": [
        "Provide a technical analysis of this image including camera details, lighting, composition, and quality.",
        "Provide a technical analysis of this image including camera details, lighting, composition, and quality within {word_count} words.",
        "Provide a {length} technical analysis of this image including camera details, lighting, composition, and quality.",
    ],
}

# Caption types that additionally get a condensed "AI image generation" prompt.
_AI_PROMPT_TYPES = ("AI Generation Prompt", "MidJourney", "Stable Diffusion")


class ImageAdapter(nn.Module):
    """Project vision-encoder features into an output embedding space.

    The features are wrapped with learned start/end token embeddings.
    This class is not used by the BLIP pipeline below; it is kept for a
    future custom-model path.

    Args:
        input_features: Feature width of one selected vision layer.
        output_features: Width of the projected embeddings.
        ln1: Apply a LayerNorm before projection (Identity otherwise).
        pos_emb: Add a learned positional embedding of shape
            (num_image_tokens, input_features).
        num_image_tokens: Token count expected when ``pos_emb`` is enabled.
        deep_extract: Concatenate five layers' outputs instead of using one.
            The projection input width is then ``input_features * 5``.
    """

    def __init__(self, input_features: int, output_features: int, ln1: bool, pos_emb: bool,
                 num_image_tokens: int, deep_extract: bool):
        super().__init__()
        self.deep_extract = deep_extract

        if self.deep_extract:
            # Five layers are concatenated along the feature axis in forward().
            input_features = input_features * 5

        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)

        self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
        self.pos_emb = None if not pos_emb else nn.Parameter(torch.zeros(num_image_tokens, input_features))

        # Other tokens (<|image_start|>, <|image_end|>, <|eot_id|>)
        self.other_tokens = nn.Embedding(3, output_features)
        self.other_tokens.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self, vision_outputs: torch.Tensor):
        """Return ``[start, projected image tokens..., end]`` along dim 1.

        ``vision_outputs`` is indexed per layer, presumably as a
        hidden-states tuple from a vision encoder with at least 21 entries
        when ``deep_extract`` is set. Each entry is asserted to be 3-D
        (batch, tokens, features). TODO confirm against the caller.
        """
        if self.deep_extract:
            x = torch.concat((
                vision_outputs[-2],
                vision_outputs[3],
                vision_outputs[7],
                vision_outputs[13],
                vision_outputs[20],
            ), dim=-1)
            assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}"
            assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5, f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
        else:
            x = vision_outputs[-2]

        x = self.ln1(x)

        if self.pos_emb is not None:
            assert x.shape[-2:] == self.pos_emb.shape, f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}"
            x = x + self.pos_emb

        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)

        # <|image_start|>, IMAGE, <|image_end|>
        other_tokens = self.other_tokens(torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(x.shape[0], -1))
        assert other_tokens.shape == (x.shape[0], 2, x.shape[2]), f"Expected {(x.shape[0], 2, x.shape[2])}, got {other_tokens.shape}"
        x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)

        return x

    def get_eot_embedding(self):
        """Return the learned ``<|eot_id|>`` embedding (index 2) as a 1-D tensor."""
        return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)


# Model loading functions
def load_siglip_model():
    """Load the SigLIP vision tower (frozen, eval mode) and its processor onto DEVICE."""
    print("Loading SigLIP model...")
    model_path = "google/siglip-so400m-patch14-384"
    processor = AutoProcessor.from_pretrained(model_path)
    model = AutoModel.from_pretrained(model_path)
    model = model.vision_model
    model.eval()
    model.requires_grad_(False)
    model.to(DEVICE)
    return model, processor


def load_blip_model():
    """Load the BLIP large captioning model (eval mode) and its processor onto DEVICE."""
    print("Loading BLIP model...")
    processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
    model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
    model.to(DEVICE)
    model.eval()
    return model, processor


# Initialize models (with optional lazy loading)
class ModelManager:
    """Hold model handles and load them on first use.

    Only the BLIP model/processor are populated by ``load_models``. The
    SigLIP, adapter, LLM and tokenizer slots stay ``None`` in this file.
    """

    def __init__(self):
        self.blip_model = None
        self.blip_processor = None
        self.siglip_model = None
        self.siglip_processor = None
        self.image_adapter = None
        self.llm_model = None
        self.tokenizer = None
        self.models_loaded = False

    def load_models(self):
        """Load BLIP once; later calls are no-ops. Returns ``self``."""
        if not self.models_loaded:
            # Load BLIP model for basic captioning
            self.blip_model, self.blip_processor = load_blip_model()
            # For more advanced captioning, set up paths to load custom models.
            # For now, BLIP is used for both simple and advanced operations.
            self.models_loaded = True
        return self


model_manager = ModelManager()


def generate_basic_caption(image, prompt="a detailed caption of this image:", max_new_tokens=100):
    """Generate a caption for ``image`` with BLIP, conditioned on ``prompt``.

    Args:
        image: A PIL image (as delivered by the ``gr.Image(type="pil")`` input).
        prompt: Text passed to the BLIP processor alongside the image.
        max_new_tokens: Upper bound on generated tokens (default 100).

    Returns:
        The decoded caption with special tokens removed.
    """
    # NOTE(review): BLIP conditional captioning treats `prompt` as a text
    # prefix rather than an instruction, so the decoded output may echo it.
    # Confirm and consider stripping it.
    model_manager.load_models()
    inputs = model_manager.blip_processor(image, prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        outputs = model_manager.blip_model.generate(**inputs, max_new_tokens=max_new_tokens)
    return model_manager.blip_processor.decode(outputs[0], skip_special_tokens=True)


def _select_prompt_template(caption_type, caption_length):
    """Pick the CAPTION_TYPE_MAP template matching the kind of length requested.

    Returns template 0 for "any"/None, 1 for a numeric string or int, and 2
    for any other string (e.g. "short"). Unknown caption types fall back to
    "Descriptive".
    """
    length = None if caption_length == "any" else caption_length
    if isinstance(length, str):
        try:
            length = int(length)
        except ValueError:
            pass  # descriptive length such as "short"

    if length is None:
        map_idx = 0
    elif isinstance(length, int):
        map_idx = 1
    else:
        map_idx = 2
    return CAPTION_TYPE_MAP.get(caption_type, CAPTION_TYPE_MAP["Descriptive"])[map_idx]


def _build_prompt(caption_type, caption_length, extra_options, custom_prompt):
    """Build the main caption prompt; a non-blank custom prompt overrides everything."""
    if custom_prompt and custom_prompt.strip():
        return custom_prompt.strip()

    template = _select_prompt_template(caption_type, caption_length)
    # Fill placeholders BEFORE appending free-form option text, so braces in
    # that text can never be interpreted as format fields.
    prompt_str = template.format(length=caption_length, word_count=caption_length)
    if extra_options:
        prompt_str += " " + " ".join(extra_options)
    return prompt_str


def _collect_descriptions(image, prompt_str, detail_level, emotion_focus, style_focus):
    """Run BLIP once per enabled aspect; return ``[(title, text), ...]``.

    The first entry is always the basic caption for ``prompt_str``.
    """
    sections = [
        (detail_level >= 2, "Main Subject(s)",
         "Describe the main subjects in this image with details about their appearance:"),
        (detail_level >= 3, "Setting/Background",
         "Describe the setting, location, and background of this image:"),
        (style_focus >= 3, "Visual Style/Colors",
         "Describe the color scheme, visual composition, and artistic style of this image:"),
        (emotion_focus >= 3, "Mood/Emotional Tone",
         "Describe the mood, emotional tone, and atmosphere conveyed in this image:"),
        (detail_level >= 4 or style_focus >= 4, "Lighting/Atmosphere",
         "Describe the lighting conditions and time of day in this image:"),
        (detail_level >= 5, "Fine Details/Textures",
         "Describe the fine details, textures, and small elements visible in this image:"),
    ]

    descriptions = [("Basic Caption", generate_basic_caption(image, prompt_str))]
    for enabled, title, section_prompt in sections:
        if enabled:
            descriptions.append((title, generate_basic_caption(image, section_prompt)))
    return descriptions


def _build_ai_prompt(descriptions, detail_level, emotion_focus, style_focus):
    """Condense the descriptions into a comma-separated image-generation prompt.

    The prompt is the basic caption plus the first sentence of each detailed
    section longer than 10 characters, followed by setting-based qualifiers.
    """
    parts = [descriptions[0][1].strip(".")]
    parts.extend(desc.split(".")[0] for _, desc in descriptions[1:] if len(desc) > 10)

    qualifiers = []
    if detail_level >= 4:
        qualifiers += ["highly detailed", "intricate"]
    if emotion_focus >= 4:
        qualifiers += ["emotional", "evocative"]
    if style_focus >= 4:
        qualifiers += ["artistic composition", "professional photography"]

    ai_prompt = ", ".join(parts)
    if qualifiers:
        ai_prompt += ", " + ", ".join(qualifiers)
    return ai_prompt


def generate_advanced_description(image, caption_type, caption_length, detail_level,
                                  emotion_focus, style_focus, extra_options, custom_prompt):
    """Generate a multi-section Markdown description of ``image``.

    Args:
        image: PIL image from the UI, or None if nothing was uploaded.
        caption_type: Key of CAPTION_TYPE_MAP (unknown keys use "Descriptive").
        caption_length: "any", a descriptive length, or a numeric string.
        detail_level, emotion_focus, style_focus: Slider values (1-5) that
            enable extra description sections and AI-prompt qualifiers.
        extra_options: Selected option sentences appended to the prompt.
        custom_prompt: If non-blank, replaces the generated main prompt.

    Returns:
        Markdown text. Any exception during generation is turned into an
        error message (with the traceback printed to stderr) rather than
        raised, so the UI always receives a string.
    """
    if image is None:
        return "Please upload an image to generate a description."

    try:
        # Load models if not already loaded
        model_manager.load_models()

        prompt_str = _build_prompt(caption_type, caption_length, extra_options, custom_prompt)
        print(f"Using prompt: {prompt_str}")

        descriptions = _collect_descriptions(image, prompt_str, detail_level, emotion_focus, style_focus)
        basic_caption = descriptions[0][1]

        result_parts = [f"## Basic Caption:\n{basic_caption}\n\n"]

        # Show every extra section that was generated. Previously this was
        # gated on detail_level >= 2, which silently discarded style/mood
        # sections that had already been computed at detail level 1.
        if len(descriptions) > 1:
            result_parts.append("## Detailed Description:\n\n")
            result_parts.extend(f"**{title}:** {desc}\n\n" for title, desc in descriptions[1:])

        if caption_type in _AI_PROMPT_TYPES:
            result_parts.append("## Suggested AI Image Generation Prompt:\n\n")
            result_parts.append(_build_ai_prompt(descriptions, detail_level, emotion_focus, style_focus))

        return "".join(result_parts)

    except Exception as e:
        # Top-level UI boundary: keep the traceback for debugging, but return
        # a readable message to the user instead of crashing the handler.
        traceback.print_exc()
        return f"Error generating description: {str(e)}"


# Create Gradio interface
with gr.Blocks(title="Enhanced Image Captioning Studio") as demo:
    gr.HTML(TITLE)
    gr.Markdown("Upload an image to generate detailed captions and descriptions tailored to your needs.")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Upload Image", type="pil")

            caption_type = gr.Dropdown(
                choices=list(CAPTION_TYPE_MAP.keys()),
                label="Caption Type",
                value="Descriptive",
            )

            caption_length = gr.Dropdown(
                choices=["any", "very short", "short", "medium-length", "long", "very long"]
                + [str(i) for i in range(20, 301, 20)],
                label="Caption Length",
                value="medium-length",
            )

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    detail_slider = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Detail Level")
                    emotion_slider = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Emotion Focus")
                    style_slider = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Style/Artistic Focus")

                extra_options = gr.CheckboxGroup(
                    choices=[
                        "Include information about lighting.",
                        "Include information about camera angle.",
                        "Include information about whether there is a watermark or not.",
                        "Include information about any artifacts or quality issues.",
                        "If it is a photo, include likely camera details such as aperture, shutter speed, ISO, etc.",
                        "Do NOT include anything sexual; keep it PG.",
                        "Do NOT mention the image's resolution.",
                        "Include information about the subjective aesthetic quality of the image.",
                        "Include information on the image's composition style.",
                        "Do NOT mention any text that is in the image.",
                        "Specify the depth of field and focus.",
                        "Mention the likely use of artificial or natural lighting sources.",
                        "ONLY describe the most important elements of the image.",
                    ],
                    label="Additional Options",
                )

                custom_prompt = gr.Textbox(label="Custom Prompt (optional, will override other settings)")
                gr.Markdown("**Note:** Custom prompts may not work with all models and settings.")

            generate_btn = gr.Button("Generate Description", variant="primary")

        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Generated Description", lines=25)

    # Set up event handlers
    generate_btn.click(
        fn=generate_advanced_description,
        inputs=[
            input_image,
            caption_type,
            caption_length,
            detail_slider,
            emotion_slider,
            style_slider,
            extra_options,
            custom_prompt,
        ],
        outputs=output_text,
    )

    gr.Markdown("""
## How to Use
1. Upload an image
2. Select the type of caption you want
3. Choose a length preference
4. Adjust advanced settings if needed:
   - Detail Level: Controls the comprehensiveness of the description
   - Emotion Focus: Emphasizes mood and feelings in the output
   - Style Focus: Emphasizes artistic elements in the output
5. Select any additional options you'd like included
6. Click "Generate Description"

## About
This application combines multiple image analysis techniques to generate rich, detailed descriptions of images.
It's especially useful for creating prompts for AI image generators like Stable Diffusion, Midjourney, or DALL-E.
""")

# Launch the app
if __name__ == "__main__":
    demo.launch()