# Source: Hugging Face Space by markpbaggett — commit c79cf18 ("Fix order issue.")
import gradio as gr
import torch
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from PIL import Image
import json
# Try to import qwen_vl_utils, fallback if not available
try:
from qwen_vl_utils import process_vision_info
QWEN_UTILS_AVAILABLE = True
except ImportError:
print("Warning: qwen_vl_utils not available, using fallback processing")
QWEN_UTILS_AVAILABLE = False
# Global variables to store model and processor
model = None
processor = None
tokenizer = None
def process_vision_info_fallback(messages):
    """Minimal stand-in for ``qwen_vl_utils.process_vision_info``.

    Scans the chat ``messages`` and collects the raw image/video payloads
    attached to user turns, in order of appearance.

    Args:
        messages: Chat-format list of dicts with ``role`` and ``content``
            keys, where ``content`` is a list of typed parts.

    Returns:
        tuple[list, list]: ``(image_inputs, video_inputs)``.
    """
    # Only user turns carry vision payloads; other roles are skipped.
    image_inputs = [
        part["image"]
        for msg in messages
        if msg.get("role") == "user"
        for part in msg.get("content", [])
        if part.get("type") == "image"
    ]
    video_inputs = [
        part["video"]
        for msg in messages
        if msg.get("role") == "user"
        for part in msg.get("content", [])
        if part.get("type") == "video"
    ]
    return image_inputs, video_inputs
def load_model():
    """Load the Qwen2.5-VL model and processor with better error handling.

    Lazily populates the module-level ``model``, ``processor`` and
    ``tokenizer`` globals on first call; later calls return the cached
    objects without reloading.

    Loading strategy (in order):
        1. Qwen2.5-VL-7B-Instruct in bfloat16 with ``device_map="auto"``.
        2. Same checkpoint in float16 forced onto CPU.
        3. Older Qwen2-VL-7B-Instruct checkpoint in float16.

    Returns:
        tuple: ``(model, processor, tokenizer)``.

    Raises:
        Exception: if all three loading strategies fail (wraps the last
            underlying error message).
    """
    global model, processor, tokenizer
    # Only the model is checked for laziness; processor/tokenizer are
    # loaded in the same pass, so a cached model implies cached companions.
    if model is None:
        try:
            print("Loading Qwen2.5-VL-7B-Instruct model...")
            # Try different model loading strategies
            model_id = "Qwen/Qwen2.5-VL-7B-Instruct"
            # Load processor first (often more stable)
            print("Loading processor...")
            processor = AutoProcessor.from_pretrained(
                model_id,
                trust_remote_code=True
            )
            # Load tokenizer
            print("Loading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True
            )
            # Load model with more conservative settings
            print("Loading model... This may take a few minutes...")
            # NOTE(review): recent transformers releases expose Qwen2.5-VL via
            # Qwen2_5_VLForConditionalGeneration; loading this checkpoint with
            # the Qwen2VL class may raise and fall through to the Qwen2-VL
            # fallback below — confirm against the pinned transformers version.
            model = Qwen2VLForConditionalGeneration.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
                # Use eager attention (more compatible)
                attn_implementation="eager",
                low_cpu_mem_usage=True,
            )
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading main model: {e}")
            print("Trying alternative loading method...")
            try:
                # Fallback: try loading with different parameters.
                # ``model_id`` still refers to the Qwen2.5-VL checkpoint here;
                # processor/tokenizer from the first attempt (if they loaded)
                # are reused as-is.
                model = Qwen2VLForConditionalGeneration.from_pretrained(
                    model_id,
                    torch_dtype=torch.float16,  # Try float16 instead
                    device_map="cpu",  # Force CPU loading
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                )
                print("Model loaded with fallback method!")
            except Exception as e2:
                print(f"Fallback loading also failed: {e2}")
                print("Trying smaller Qwen2-VL model...")
                try:
                    # Try the older Qwen2-VL model as final fallback;
                    # processor and tokenizer are re-loaded to match it.
                    model_id = "Qwen/Qwen2-VL-7B-Instruct"
                    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
                    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
                    model = Qwen2VLForConditionalGeneration.from_pretrained(
                        model_id,
                        torch_dtype=torch.float16,
                        device_map="auto",
                        trust_remote_code=True,
                    )
                    print("Loaded Qwen2-VL (older version) successfully!")
                except Exception as e3:
                    raise Exception(f"All model loading attempts failed. Last error: {e3}")
    return model, processor, tokenizer
def generate_metadata(image, metadata_type):
    """Generate metadata for the uploaded image with improved error handling.

    Args:
        image: PIL image supplied by the Gradio ``gr.Image`` component
            (``None`` when nothing is uploaded).
        metadata_type: One of the dropdown labels (e.g. "Basic Description");
            unknown values fall back to "Basic Description".

    Returns:
        str: The model's generated text, or a human-readable error message —
            this function never raises, so the UI always gets a string.
    """
    if image is None:
        return "Please upload an image first."
    try:
        # Load model if not already loaded.
        # These locals shadow the module globals of the same names —
        # intentional, since load_model() returns the cached objects.
        model, processor, tokenizer = load_model()
        # Define prompts for different metadata types
        prompts = {
            "Basic Description": "Describe this image in detail, including what you see, the setting, colors, and overall composition.",
            "Technical Analysis": "Analyze this image from a technical perspective. Describe the lighting, composition, camera angle, depth of field, and any photographic techniques used.",
            "Objects & People": "List all the objects, people, animals, and items you can identify in this image. Be comprehensive and specific.",
            "Scene & Context": "Describe the scene, setting, location, time of day, weather conditions, and any contextual information you can infer from this image.",
            "Artistic Analysis": "Analyze this image from an artistic perspective, discussing the style, mood, aesthetic qualities, visual elements, and artistic techniques used.",
            "SEO Keywords": "Generate relevant SEO keywords and tags that would help categorize and find this image in a database or search system.",
            "JSON Metadata": "Create a comprehensive JSON metadata object for this image including description, objects, colors, setting, mood, and technical details."
        }
        prompt = prompts.get(metadata_type, prompts["Basic Description"])
        # Prepare the conversation format (Qwen chat schema: one user turn
        # with an image part followed by a text part).
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image,
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        # Process the input with error handling
        try:
            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            # Use appropriate vision processing
            if QWEN_UTILS_AVAILABLE:
                image_inputs, video_inputs = process_vision_info(messages)
            else:
                image_inputs, video_inputs = process_vision_info_fallback(messages)
            inputs = processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            # Move to device
            inputs = inputs.to(model.device)
        except Exception as e:
            print(f"Error in input processing: {e}")
            # Fallback to simpler processing: raw prompt + image, without the
            # chat template (degraded but usually still usable output).
            try:
                inputs = processor(
                    text=prompt,
                    images=image,
                    return_tensors="pt",
                    padding=True
                )
                inputs = inputs.to(model.device)
            except Exception as e2:
                return f"Error processing input: {str(e2)}"
        # Generate response with conservative parameters
        try:
            with torch.no_grad():
                # NOTE(review): pad_token_id may be None for some tokenizers —
                # confirm the loaded tokenizer defines it, else generate() warns.
                generated_ids = model.generate(
                    **inputs,
                    max_new_tokens=384,  # Reduced from 512
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
            # Extract and decode the response: strip the prompt tokens from
            # each sequence so only the newly generated text is decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )[0]
            return output_text.strip()
        except Exception as e:
            return f"Error during generation: {str(e)}"
    except Exception as e:
        return f"Error generating metadata: {str(e)}"
def create_interface():
    """Create the Gradio interface.

    Builds a two-column Blocks layout: image upload + metadata-type dropdown
    + generate button on the left, generated text output on the right.
    Generation runs on button click and also automatically on image upload.

    Returns:
        gr.Blocks: The assembled (un-launched) interface.
    """
    # Custom CSS injected into the Blocks page.
    css = """
    .metadata-container {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        border-radius: 15px;
        padding: 20px;
        margin: 10px 0;
    }
    .output-text {
        background-color: #f8f9fa;
        border-radius: 10px;
        padding: 15px;
        border-left: 4px solid #667eea;
    }
    """
    with gr.Blocks(css=css, title="Image Metadata Generator with Qwen2.5-VL") as interface:
        # Page header.
        gr.HTML("""
        <div style="text-align: center; padding: 20px;">
            <h1 style="color: #333; margin-bottom: 10px;">🖼️ Image Metadata Generator</h1>
            <h3 style="color: #666; font-weight: normal;">Powered by Qwen2.5-VL</h3>
            <p style="color: #888;">Upload an image and generate comprehensive metadata using AI vision</p>
        </div>
        """)
        with gr.Row():
            # Left column: inputs.
            with gr.Column(scale=1):
                image_input = gr.Image(
                    type="pil",
                    label="Upload Image",
                    height=400
                )
                # Choices must match the keys of the prompts dict in
                # generate_metadata().
                metadata_type = gr.Dropdown(
                    choices=[
                        "Basic Description",
                        "Technical Analysis",
                        "Objects & People",
                        "Scene & Context",
                        "Artistic Analysis",
                        "SEO Keywords",
                        "JSON Metadata"
                    ],
                    value="Basic Description",
                    label="Metadata Type"
                )
                generate_btn = gr.Button(
                    "Generate Metadata",
                    variant="primary",
                    size="lg"
                )
            # Right column: output.
            with gr.Column(scale=1):
                output_text = gr.Textbox(
                    label="Generated Metadata",
                    lines=20,
                    max_lines=25,
                    elem_classes=["output-text"]
                )
        # Event handlers
        generate_btn.click(
            fn=generate_metadata,
            inputs=[image_input, metadata_type],
            outputs=output_text,
            show_progress=True
        )
        # Auto-generate on image upload (empty string clears the output when
        # the image is removed).
        image_input.change(
            fn=lambda img: generate_metadata(img, "Basic Description") if img else "",
            inputs=[image_input],
            outputs=output_text,
            show_progress=True
        )
        # Page footer.
        gr.HTML("""
        <div style="text-align: center; padding: 20px; margin-top: 30px; border-top: 1px solid #eee;">
            <p style="color: #666;">
                This Space uses Qwen2.5-VL for intelligent image analysis and metadata generation.
                <br>Perfect for content management, SEO optimization, and accessibility improvements.
            </p>
            <p style="color: #888; font-size: 0.9em; margin-top: 10px;">
                <strong>Note:</strong> First generation may take 1-2 minutes while the model loads. Subsequent generations will be much faster.
            </p>
        </div>
        """)
    return interface
def initialize_app():
    """Initialize the application.

    Prints startup banners and environment diagnostics, then builds and
    returns the Gradio interface. The model itself is loaded lazily on
    first generation, not here.

    Returns:
        gr.Blocks: The assembled interface, ready to launch.
    """
    print("Starting Image Metadata Generator...")
    print("Model will be loaded on first use to save resources.")
    # Print system info for debugging GPU availability issues.
    print(f"PyTorch version: {torch.__version__}")
    cuda_ok = torch.cuda.is_available()
    print(f"CUDA available: {cuda_ok}")
    if cuda_ok:
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    return create_interface()
if __name__ == "__main__":
    # Build the UI and serve it on all interfaces. Port 7860 is the
    # conventional Hugging Face Spaces port; share=False disables the
    # public Gradio tunnel.
    app = initialize_app()
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )