import gradio as gr
import torch
from PIL import Image
import torch.nn as nn
import torchvision.transforms.functional as TVF
from transformers import AutoModel, AutoProcessor, AutoTokenizer, AutoModelForConditionalGeneration, PreTrainedTokenizer, PreTrainedTokenizerFast
# Define constants
# HTML banner rendered at the top of the page by gr.HTML(TITLE).
# (The original literal was split across three lines, which is a SyntaxError.)
TITLE = '<h1 style="text-align: center;">Enhanced Image Captioning Studio</h1>'
# Run models on CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pre-defined caption types with templates
# Prompt templates keyed by caption type. Every list holds exactly three
# templates, selected by how the caption length was specified
# (see generate_advanced_description):
#   [0] no length constraint ("any")             -> no placeholders
#   [1] a numeric word count such as "40"        -> fills {word_count}
#   [2] a descriptive length such as "short"     -> fills {length}
CAPTION_TYPE_MAP = {
    "Descriptive": [
        "Write a descriptive caption for this image in a formal tone.",
        "Write a descriptive caption for this image in a formal tone within {word_count} words.",
        "Write a {length} descriptive caption for this image in a formal tone.",
    ],
    "Descriptive (Informal)": [
        "Write a descriptive caption for this image in a casual tone.",
        "Write a descriptive caption for this image in a casual tone within {word_count} words.",
        "Write a {length} descriptive caption for this image in a casual tone.",
    ],
    # --- prompt-writing styles (these also get a condensed AI prompt section) ---
    "AI Generation Prompt": [
        "Write a detailed prompt for AI image generation based on this image.",
        "Write a detailed prompt for AI image generation based on this image within {word_count} words.",
        "Write a {length} prompt for AI image generation based on this image.",
    ],
    "MidJourney": [
        "Write a MidJourney prompt for this image.",
        "Write a MidJourney prompt for this image within {word_count} words.",
        "Write a {length} MidJourney prompt for this image.",
    ],
    "Stable Diffusion": [
        "Write a Stable Diffusion prompt for this image.",
        "Write a Stable Diffusion prompt for this image within {word_count} words.",
        "Write a {length} Stable Diffusion prompt for this image.",
    ],
    # --- other writing styles ---
    "Art Critic": [
        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc.",
        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it within {word_count} words.",
        "Analyze this image like an art critic would with information about its composition, style, symbolism, the use of color, light, any artistic movement it might belong to, etc. Keep it {length}.",
    ],
    "Product Listing": [
        "Write a caption for this image as though it were a product listing.",
        "Write a caption for this image as though it were a product listing. Keep it under {word_count} words.",
        "Write a {length} caption for this image as though it were a product listing.",
    ],
    "Social Media Post": [
        "Write a caption for this image as if it were being used for a social media post.",
        "Write a caption for this image as if it were being used for a social media post. Limit the caption to {word_count} words.",
        "Write a {length} caption for this image as if it were being used for a social media post.",
    ],
    "Tag List": [
        "Write a list of tags for this image.",
        "Write a list of tags for this image within {word_count} words.",
        "Write a {length} list of tags for this image.",
    ],
    "Technical Analysis": [
        "Provide a technical analysis of this image including camera details, lighting, composition, and quality.",
        "Provide a technical analysis of this image including camera details, lighting, composition, and quality within {word_count} words.",
        "Provide a {length} technical analysis of this image including camera details, lighting, composition, and quality.",
    ],
}
class ImageAdapter(nn.Module):
    """Project vision hidden states to ``output_features`` and frame them with learned tokens.

    Pipeline: ``ln1 -> (+ pos_emb) -> linear1 -> GELU -> linear2``. The result is
    wrapped between two learned embeddings (rows 0 and 1 of ``other_tokens``).
    Row 2 is exposed separately through ``get_eot_embedding``.

    NOTE(review): nothing in this file instantiates this class (ModelManager
    leaves ``image_adapter`` as None). It presumably exists to load a pretrained
    adapter checkpoint, so attribute names (which become ``state_dict`` keys)
    should not be renamed — confirm before refactoring.
    """

    def __init__(self, input_features: int, output_features: int, ln1: bool, pos_emb: bool, num_image_tokens: int, deep_extract: bool):
        """Create the adapter layers.

        Args:
            input_features: Feature width of a single vision hidden-state entry.
                Multiplied by 5 when ``deep_extract`` is set, because ``forward``
                then concatenates five entries along the feature axis.
            output_features: Width of both linear layers and of the learned
                special-token embeddings.
            ln1: If True, apply ``nn.LayerNorm`` to the (possibly concatenated)
                input; otherwise ``nn.Identity`` is used.
            pos_emb: If True, add a learned, zero-initialised positional embedding
                of shape ``(num_image_tokens, input_features)``.
            num_image_tokens: Token count; only used to size ``pos_emb``.
            deep_extract: If True, ``forward`` concatenates five hidden-state
                entries instead of using only ``vision_outputs[-2]``.
        """
        super().__init__()
        self.deep_extract = deep_extract
        if self.deep_extract:
            # Five entries are concatenated in forward(), so the input is 5x wider.
            input_features = input_features * 5
        # NOTE(review): keep the construction order below unchanged. Randomly
        # initialised layers (nn.Linear, nn.Embedding.normal_) presumably draw
        # from the global torch RNG, so reordering would change the initial weights.
        self.linear1 = nn.Linear(input_features, output_features)
        self.activation = nn.GELU()
        self.linear2 = nn.Linear(output_features, output_features)
        self.ln1 = nn.Identity() if not ln1 else nn.LayerNorm(input_features)
        self.pos_emb = None if not pos_emb else nn.Parameter(torch.zeros(num_image_tokens, input_features))
        # Other tokens (<|image_start|>, <|image_end|>, <|eot_id|>)
        self.other_tokens = nn.Embedding(3, output_features)
        # Re-initialise the three special-token rows from a normal with mean 0, std 0.02.
        self.other_tokens.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self, vision_outputs: torch.Tensor):
        """Turn vision hidden states into ``[start, image tokens..., end]`` embeddings.

        Args:
            vision_outputs: Indexed like a sequence of per-layer hidden states.
                Entry ``-2`` is always read. With ``deep_extract``, entries 3, 7,
                13 and 20 are also read, so at least 21 entries are required, and
                the concatenation is asserted to be rank 3.
                NOTE(review): annotated as ``torch.Tensor``, but presumably a tuple
                of ``(batch, tokens, features)`` hidden states from the vision
                model — confirm against the caller. When ``pos_emb`` is enabled,
                the last two dims must equal ``(num_image_tokens, input_features)``.

        Returns:
            torch.Tensor: ``(batch, tokens + 2, output_features)``. These are the
            projected image tokens with the start embedding prepended and the
            end embedding appended.
        """
        if self.deep_extract:
            # Concatenate five hidden-state entries along the feature axis.
            x = torch.concat((
                vision_outputs[-2],
                vision_outputs[3],
                vision_outputs[7],
                vision_outputs[13],
                vision_outputs[20],
            ), dim=-1)
            assert len(x.shape) == 3, f"Expected 3, got {len(x.shape)}"
            assert x.shape[-1] == vision_outputs[-2].shape[-1] * 5, f"Expected {vision_outputs[-2].shape[-1] * 5}, got {x.shape[-1]}"
        else:
            # Without deep_extract only the second-to-last entry is used.
            x = vision_outputs[-2]
        x = self.ln1(x)
        if self.pos_emb is not None:
            assert x.shape[-2:] == self.pos_emb.shape, f"Expected {self.pos_emb.shape}, got {x.shape[-2:]}"
            # Broadcast-add the same positional embedding to every batch item.
            x = x + self.pos_emb
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        # <|image_start|>, IMAGE, <|image_end|>
        # Look up rows 0 and 1 for each batch item -> (batch, 2, output_features).
        other_tokens = self.other_tokens(torch.tensor([0, 1], device=self.other_tokens.weight.device).expand(x.shape[0], -1))
        assert other_tokens.shape == (x.shape[0], 2, x.shape[2]), f"Expected {(x.shape[0], 2, x.shape[2])}, got {other_tokens.shape}"
        # Sandwich the image tokens between the start and end embeddings on the token axis.
        x = torch.cat((other_tokens[:, 0:1], x, other_tokens[:, 1:2]), dim=1)
        return x

    def get_eot_embedding(self):
        """Return row 2 of ``other_tokens`` (the ``<|eot_id|>`` slot).

        Returns:
            torch.Tensor: 1-D tensor of shape ``(output_features,)``.
        """
        return self.other_tokens(torch.tensor([2], device=self.other_tokens.weight.device)).squeeze(0)
# Model loading functions
def load_siglip_model():
    """Load the SigLIP so400m vision tower and its processor onto DEVICE.

    Only the ``vision_model`` sub-module of the full SigLIP model is kept. It is
    switched to eval mode, has gradients disabled, and is moved to DEVICE.
    Note that ModelManager in this file does not call this loader.

    Returns:
        tuple: ``(vision_model, processor)``.
    """
    print("Loading SigLIP model...")
    checkpoint = "google/siglip-so400m-patch14-384"
    siglip_processor = AutoProcessor.from_pretrained(checkpoint)
    vision_tower = AutoModel.from_pretrained(checkpoint).vision_model
    # These nn.Module methods modify the module in place; their return values are not needed.
    vision_tower.eval()
    vision_tower.requires_grad_(False)
    vision_tower.to(DEVICE)
    return vision_tower, siglip_processor
def load_blip_model():
    """Load the BLIP large image-captioning model and its processor onto DEVICE.

    ``transformers`` provides no ``AutoModelForConditionalGeneration`` class.
    The BLIP-specific ``BlipForConditionalGeneration`` is the model class for
    this checkpoint, so it is used directly.

    Returns:
        tuple: ``(model, processor)``, with the model on DEVICE in eval mode.
    """
    # Local import: the concrete BLIP class is only needed by this loader.
    from transformers import BlipForConditionalGeneration

    print("Loading BLIP model...")
    model_path = "Salesforce/blip-image-captioning-large"
    processor = AutoProcessor.from_pretrained(model_path)
    model = BlipForConditionalGeneration.from_pretrained(model_path)
    model.to(DEVICE)
    model.eval()
    return model, processor
# Initialize models (with optional lazy loading)
class ModelManager:
    """Hold the models used by the captioning functions and load them lazily.

    Every slot starts as ``None``. ``load_models`` currently fills only the BLIP
    model/processor pair. The SigLIP, adapter, LLM and tokenizer slots are
    placeholders that nothing in this file populates.
    """

    def __init__(self):
        # BLIP captioning pair, filled in by load_models().
        self.blip_model = None
        self.blip_processor = None
        # Placeholders for a fuller SigLIP + adapter + LLM pipeline.
        self.siglip_model = None
        self.siglip_processor = None
        self.image_adapter = None
        self.llm_model = None
        self.tokenizer = None
        # Set once loading has happened so later calls skip it.
        self.models_loaded = False

    def load_models(self):
        """Load the BLIP model on the first call; later calls do nothing.

        Returns:
            ModelManager: ``self``, so the call can be chained.
        """
        if self.models_loaded:
            return self
        # Only BLIP is loaded; both basic and advanced captioning run on it for now.
        self.blip_model, self.blip_processor = load_blip_model()
        self.models_loaded = True
        return self


# Shared module-level instance used by the captioning functions.
model_manager = ModelManager()
def generate_basic_caption(image, prompt="a detailed caption of this image:"):
    """Generate a basic caption using BLIP model

    The shared ``model_manager`` loads the models on first use.

    Args:
        image: The image to caption (the Gradio input is configured with ``type="pil"``).
        prompt: Text passed to the BLIP processor together with the image.

    Returns:
        str: The first generated sequence decoded with special tokens skipped.
        NOTE(review): BLIP presumably treats the text as a decoder prefix, so
        the decoded caption likely starts with the prompt itself — confirm.
    """
    manager = model_manager.load_models()
    processor = manager.blip_processor
    encoded = processor(image, prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        generated = manager.blip_model.generate(**encoded, max_new_tokens=100)
    first_sequence = generated[0]
    return processor.decode(first_sequence, skip_special_tokens=True)
# Caption types whose output also gets a condensed prompt for image generators.
_AI_PROMPT_CAPTION_TYPES = ("AI Generation Prompt", "MidJourney", "Stable Diffusion")


def _build_prompt(caption_type, caption_length, extra_options, custom_prompt):
    """Return the prompt used for the basic caption, derived from the UI settings."""
    # A non-blank custom prompt replaces the template and the extra options entirely.
    if custom_prompt and custom_prompt.strip():
        return custom_prompt.strip()
    length = None if caption_length == "any" else caption_length
    if isinstance(length, str):
        try:
            length = int(length)
        except ValueError:
            pass  # Descriptive lengths such as "short" stay as strings.
    # Template index: 0 = no constraint, 1 = numeric word count, 2 = descriptive length.
    if length is None:
        map_idx = 0
    elif isinstance(length, int):
        map_idx = 1
    else:
        map_idx = 2
    templates = CAPTION_TYPE_MAP.get(caption_type, CAPTION_TYPE_MAP["Descriptive"])
    # Fill placeholders before appending user-selected text, so braces in that text can't break .format().
    prompt = templates[map_idx].format(length=caption_length, word_count=caption_length)
    if extra_options:
        prompt += " " + " ".join(extra_options)
    return prompt


def _collect_descriptions(image, detail_level, emotion_focus, style_focus):
    """Run each targeted BLIP prompt enabled by the sliders; return (title, text) pairs in display order."""
    sections = (
        (detail_level >= 2, "Main Subject(s)",
         "Describe the main subjects in this image with details about their appearance:"),
        (detail_level >= 3, "Setting/Background",
         "Describe the setting, location, and background of this image:"),
        (style_focus >= 3, "Visual Style/Colors",
         "Describe the color scheme, visual composition, and artistic style of this image:"),
        (emotion_focus >= 3, "Mood/Emotional Tone",
         "Describe the mood, emotional tone, and atmosphere conveyed in this image:"),
        (detail_level >= 4 or style_focus >= 4, "Lighting/Atmosphere",
         "Describe the lighting conditions and time of day in this image:"),
        (detail_level >= 5, "Fine Details/Textures",
         "Describe the fine details, textures, and small elements visible in this image:"),
    )
    return [(title, generate_basic_caption(image, prompt)) for enabled, title, prompt in sections if enabled]


def _build_ai_prompt(basic_caption, sections, detail_level, emotion_focus, style_focus):
    """Condense the captions into one comma-separated image-generation prompt plus slider-based qualifiers."""
    parts = [basic_caption.strip(".")]
    # Keep only the first sentence of each non-trivial section.
    parts.extend(desc.split(".")[0] for _, desc in sections if len(desc) > 10)
    if detail_level >= 4:
        parts += ["highly detailed", "intricate"]
    if emotion_focus >= 4:
        parts += ["emotional", "evocative"]
    if style_focus >= 4:
        parts += ["artistic composition", "professional photography"]
    return ", ".join(parts)


def _format_result(basic_caption, sections, caption_type, detail_level, emotion_focus, style_focus):
    """Assemble the Markdown output shown in the UI."""
    chunks = [f"## Basic Caption:\n{basic_caption}\n\n"]
    # Render whenever any targeted section ran: the style/emotion sliders can enable
    # sections even at detail_level 1, and their output must not be silently dropped.
    if sections:
        chunks.append("## Detailed Description:\n\n")
        chunks.extend(f"**{title}:** {desc}\n\n" for title, desc in sections)
    if caption_type in _AI_PROMPT_CAPTION_TYPES:
        chunks.append("## Suggested AI Image Generation Prompt:\n\n")
        chunks.append(_build_ai_prompt(basic_caption, sections, detail_level, emotion_focus, style_focus))
    return "".join(chunks)


def generate_advanced_description(image, caption_type, caption_length, detail_level, emotion_focus, style_focus, extra_options, custom_prompt):
    """Generate an advanced description using multiple targeted prompts.

    First, a basic caption is generated from a templated prompt (or from
    ``custom_prompt``). Then extra BLIP prompts run, depending on the slider
    levels, and everything is formatted as Markdown.

    Args:
        image: Uploaded image, or None when nothing has been uploaded.
        caption_type: Key of CAPTION_TYPE_MAP; unknown keys fall back to "Descriptive".
        caption_length: "any", a descriptive length such as "short", or a word count as a string.
        detail_level: Slider value (1-5 in the UI) enabling subject/setting/lighting/detail sections.
        emotion_focus: Slider value enabling the mood section and the emotional qualifiers.
        style_focus: Slider value enabling the style/lighting sections and the artistic qualifiers.
        extra_options: Sentences appended to the templated prompt (unused when a custom prompt is given).
        custom_prompt: If non-blank, replaces the templated prompt for the basic caption only.

    Returns:
        str: Markdown description, an upload hint when ``image`` is None, or an error message.
    """
    if image is None:
        return "Please upload an image to generate a description."
    try:
        model_manager.load_models()
        prompt_str = _build_prompt(caption_type, caption_length, extra_options, custom_prompt)
        print(f"Using prompt: {prompt_str}")
        basic_caption = generate_basic_caption(image, prompt_str)
        sections = _collect_descriptions(image, detail_level, emotion_focus, style_focus)
        return _format_result(basic_caption, sections, caption_type, detail_level, emotion_focus, style_focus)
    except Exception as e:
        # UI boundary: show the error in the output box, but keep the full traceback in the server log.
        import traceback
        traceback.print_exc()
        return f"Error generating description: {str(e)}"
# Create Gradio interface
# Gradio UI: a two-column layout with controls on the left and the generated text on the right.
with gr.Blocks(title="Enhanced Image Captioning Studio") as demo:
    gr.HTML(TITLE)
    gr.Markdown("Upload an image to generate detailed captions and descriptions tailored to your needs.")
    with gr.Row():
        # Left column: image upload and all caption settings.
        with gr.Column(scale=1):
            input_image = gr.Image(label="Upload Image", type="pil")
            caption_type = gr.Dropdown(
                choices=list(CAPTION_TYPE_MAP.keys()),
                label="Caption Type",
                value="Descriptive",
            )
            # "any", a descriptive length, or a word count from 20 to 300 in steps of 20 (passed as strings).
            caption_length = gr.Dropdown(
                choices=["any", "very short", "short", "medium-length", "long", "very long"] +
                [str(i) for i in range(20, 301, 20)],
                label="Caption Length",
                value="medium-length",
            )
            with gr.Accordion("Advanced Settings", open=False):
                # Integer levels 1-5. generate_advanced_description compares them against
                # thresholds (2-5) to decide which extra prompts to run.
                with gr.Row():
                    detail_slider = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Detail Level")
                    emotion_slider = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Emotion Focus")
                    style_slider = gr.Slider(minimum=1, maximum=5, value=3, step=1, label="Style/Artistic Focus")
                # Each checked sentence is appended, space-separated, to the templated prompt.
                extra_options = gr.CheckboxGroup(
                    choices=[
                        "Include information about lighting.",
                        "Include information about camera angle.",
                        "Include information about whether there is a watermark or not.",
                        "Include information about any artifacts or quality issues.",
                        "If it is a photo, include likely camera details such as aperture, shutter speed, ISO, etc.",
                        "Do NOT include anything sexual; keep it PG.",
                        "Do NOT mention the image's resolution.",
                        "Include information about the subjective aesthetic quality of the image.",
                        "Include information on the image's composition style.",
                        "Do NOT mention any text that is in the image.",
                        "Specify the depth of field and focus.",
                        "Mention the likely use of artificial or natural lighting sources.",
                        "ONLY describe the most important elements of the image."
                    ],
                    label="Additional Options"
                )
                # NOTE(review): the label says this overrides other settings, but in
                # generate_advanced_description it only replaces the basic-caption prompt.
                # The slider-driven sub-prompts still run.
                custom_prompt = gr.Textbox(label="Custom Prompt (optional, will override other settings)")
                gr.Markdown("**Note:** Custom prompts may not work with all models and settings.")
            generate_btn = gr.Button("Generate Description", variant="primary")
        # Right column: generated Markdown description (shown as plain text).
        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Generated Description", lines=25)
    # Set up event handlers
    # The order of `inputs` must match generate_advanced_description's positional parameters.
    generate_btn.click(
        fn=generate_advanced_description,
        inputs=[
            input_image,
            caption_type,
            caption_length,
            detail_slider,
            emotion_slider,
            style_slider,
            extra_options,
            custom_prompt
        ],
        outputs=output_text
    )
    gr.Markdown("""
## How to Use
1. Upload an image
2. Select the type of caption you want
3. Choose a length preference
4. Adjust advanced settings if needed:
- Detail Level: Controls the comprehensiveness of the description
- Emotion Focus: Emphasizes mood and feelings in the output
- Style Focus: Emphasizes artistic elements in the output
5. Select any additional options you'd like included
6. Click "Generate Description"
## About
This application combines multiple image analysis techniques to generate rich,
detailed descriptions of images. It's especially useful for creating prompts
for AI image generators like Stable Diffusion, Midjourney, or DALL-E.
""")
# Launch the app
if __name__ == "__main__":
    demo.launch()