|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor |
|
|
from PIL import Image |
|
|
import json |
|
|
|
|
|
|
|
|
# qwen_vl_utils ships the official helper that extracts image/video inputs
# from Qwen chat messages. It is an optional dependency, so record whether
# it imported and fall back to the local re-implementation below when not.
try:
    from qwen_vl_utils import process_vision_info

    QWEN_UTILS_AVAILABLE = True
except ImportError:
    print("Warning: qwen_vl_utils not available, using fallback processing")

    QWEN_UTILS_AVAILABLE = False
|
|
|
|
|
|
|
|
# Lazily-initialized singletons shared across all requests.
# They stay None until load_model() populates them on the first
# generation call, so importing this module costs nothing.
model = None
processor = None
tokenizer = None
|
|
|
|
|
def process_vision_info_fallback(messages):
    """Fallback for qwen_vl_utils.process_vision_info.

    Scans the chat ``messages`` and collects the raw payloads of every
    ``image``/``video`` content part found in user turns, preserving
    message order.

    Returns:
        A ``(image_inputs, video_inputs)`` pair of lists.
    """
    # Only user turns carry vision inputs; other roles are skipped.
    user_parts = [
        part
        for msg in messages
        if msg.get("role") == "user"
        for part in msg.get("content", [])
    ]

    image_inputs = [p["image"] for p in user_parts if p.get("type") == "image"]
    video_inputs = [p["video"] for p in user_parts if p.get("type") == "video"]

    return image_inputs, video_inputs
|
|
|
|
|
def load_model():
    """Lazily load the Qwen2.5-VL model, processor and tokenizer.

    Uses the module-level globals as a singleton cache so the expensive
    load happens only once, on the first call. Loading strategies are
    tried in order:

      1. Qwen2.5-VL-7B-Instruct, bf16, device_map="auto"
      2. the same checkpoint on CPU with fp16 (low-memory fallback)
      3. the older Qwen2-VL-7B-Instruct checkpoint

    Returns:
        (model, processor, tokenizer) tuple.

    Raises:
        RuntimeError: if every loading strategy fails (subclass of
            Exception, so existing callers catching Exception still work).
    """
    global model, processor, tokenizer

    if model is None:
        # Qwen2.5-VL checkpoints require the Qwen2_5_VL* model class on
        # recent transformers releases; loading them through the Qwen2VL
        # class fails with a config/architecture mismatch. Fall back to the
        # Qwen2VL class only when the newer one is unavailable.
        try:
            from transformers import Qwen2_5_VLForConditionalGeneration as _Qwen25VLClass
        except ImportError:
            _Qwen25VLClass = Qwen2VLForConditionalGeneration

        model_id = "Qwen/Qwen2.5-VL-7B-Instruct"

        try:
            print("Loading Qwen2.5-VL-7B-Instruct model...")

            print("Loading processor...")
            processor = AutoProcessor.from_pretrained(
                model_id,
                trust_remote_code=True
            )

            print("Loading tokenizer...")
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True
            )

            print("Loading model... This may take a few minutes...")
            model = _Qwen25VLClass.from_pretrained(
                model_id,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=True,
                # "eager" avoids a hard dependency on flash-attention.
                attn_implementation="eager",
                low_cpu_mem_usage=True,
            )

            print("Model loaded successfully!")

        except Exception as e:
            print(f"Error loading main model: {e}")
            print("Trying alternative loading method...")

            try:
                # Fallback 1: same checkpoint, CPU-only, fp16. Reload the
                # processor/tokenizer too, in case the first attempt failed
                # before they were assigned.
                processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
                tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
                model = _Qwen25VLClass.from_pretrained(
                    model_id,
                    torch_dtype=torch.float16,
                    device_map="cpu",
                    trust_remote_code=True,
                    low_cpu_mem_usage=True,
                )
                print("Model loaded with fallback method!")

            except Exception as e2:
                print(f"Fallback loading also failed: {e2}")
                print("Trying smaller Qwen2-VL model...")

                try:
                    # Fallback 2: the previous-generation checkpoint, which
                    # the Qwen2VL class loads natively.
                    model_id = "Qwen/Qwen2-VL-7B-Instruct"
                    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
                    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
                    model = Qwen2VLForConditionalGeneration.from_pretrained(
                        model_id,
                        torch_dtype=torch.float16,
                        device_map="auto",
                        trust_remote_code=True,
                    )
                    print("Loaded Qwen2-VL (older version) successfully!")

                except Exception as e3:
                    # Chain the cause so the original traceback survives.
                    raise RuntimeError(
                        f"All model loading attempts failed. Last error: {e3}"
                    ) from e3

    return model, processor, tokenizer
|
|
|
|
|
def generate_metadata(image, metadata_type):
    """Generate metadata for the uploaded image with improved error handling.

    Args:
        image: PIL image from the Gradio component (type="pil"), or None.
        metadata_type: One of the dropdown choices; unknown values fall
            back to "Basic Description".

    Returns:
        The generated metadata text, or a human-readable error string
        (this function never raises — the UI displays whatever it returns).
    """
    if image is None:
        return "Please upload an image first."

    try:
        # Lazy-load (or fetch the cached) model/processor/tokenizer.
        # First call may take minutes; later calls are instant.
        model, processor, tokenizer = load_model()

        # One instruction prompt per metadata style offered in the UI.
        prompts = {
            "Basic Description": "Describe this image in detail, including what you see, the setting, colors, and overall composition.",
            "Technical Analysis": "Analyze this image from a technical perspective. Describe the lighting, composition, camera angle, depth of field, and any photographic techniques used.",
            "Objects & People": "List all the objects, people, animals, and items you can identify in this image. Be comprehensive and specific.",
            "Scene & Context": "Describe the scene, setting, location, time of day, weather conditions, and any contextual information you can infer from this image.",
            "Artistic Analysis": "Analyze this image from an artistic perspective, discussing the style, mood, aesthetic qualities, visual elements, and artistic techniques used.",
            "SEO Keywords": "Generate relevant SEO keywords and tags that would help categorize and find this image in a database or search system.",
            "JSON Metadata": "Create a comprehensive JSON metadata object for this image including description, objects, colors, setting, mood, and technical details."
        }

        prompt = prompts.get(metadata_type, prompts["Basic Description"])

        # Qwen chat format: a single user turn carrying the image plus the
        # text instruction.
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image,
                    },
                    {"type": "text", "text": prompt},
                ],
            }
        ]

        try:
            # Render the chat template to a prompt string containing the
            # image placeholder tokens the processor expects.
            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            # Extract the vision inputs with the official helper when
            # installed, otherwise with the local fallback defined above.
            if QWEN_UTILS_AVAILABLE:
                image_inputs, video_inputs = process_vision_info(messages)
            else:
                image_inputs, video_inputs = process_vision_info_fallback(messages)

            inputs = processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )

            # Move the batch to wherever device_map placed the model.
            inputs = inputs.to(model.device)

        except Exception as e:
            print(f"Error in input processing: {e}")

            # Best-effort fallback: feed the raw prompt and image directly,
            # without the chat template (may not work on all versions).
            try:
                inputs = processor(
                    text=prompt,
                    images=image,
                    return_tensors="pt",
                    padding=True
                )
                inputs = inputs.to(model.device)
            except Exception as e2:
                return f"Error processing input: {str(e2)}"

        try:
            # Sampled decoding; no_grad keeps memory use down at inference.
            with torch.no_grad():
                generated_ids = model.generate(
                    **inputs,
                    max_new_tokens=384,
                    temperature=0.7,
                    do_sample=True,
                    top_p=0.9,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )

            # generate() returns prompt + completion; slice off the prompt
            # tokens per batch row before decoding.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )[0]

            return output_text.strip()

        except Exception as e:
            return f"Error during generation: {str(e)}"

    except Exception as e:
        return f"Error generating metadata: {str(e)}"
|
|
|
|
|
def create_interface():
    """Create the Gradio interface.

    Builds a two-column Blocks layout: image upload + metadata-type
    dropdown + button on the left, generated text on the right. Returns
    the (not yet launched) gr.Blocks instance.
    """
    # Custom styling injected into the page; .output-text is attached to
    # the results textbox via elem_classes below.
    css = """
    .metadata-container {
        background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
        border-radius: 15px;
        padding: 20px;
        margin: 10px 0;
    }
    .output-text {
        background-color: #f8f9fa;
        border-radius: 10px;
        padding: 15px;
        border-left: 4px solid #667eea;
    }
    """

    with gr.Blocks(css=css, title="Image Metadata Generator with Qwen2.5-VL") as interface:
        # Page header.
        gr.HTML("""
        <div style="text-align: center; padding: 20px;">
            <h1 style="color: #333; margin-bottom: 10px;">🖼️ Image Metadata Generator</h1>
            <h3 style="color: #666; font-weight: normal;">Powered by Qwen2.5-VL</h3>
            <p style="color: #888;">Upload an image and generate comprehensive metadata using AI vision</p>
        </div>
        """)

        with gr.Row():
            # Left column: inputs.
            with gr.Column(scale=1):
                # type="pil" hands generate_metadata a PIL.Image directly.
                image_input = gr.Image(
                    type="pil",
                    label="Upload Image",
                    height=400
                )

                # Choices must match the keys of the prompts dict in
                # generate_metadata.
                metadata_type = gr.Dropdown(
                    choices=[
                        "Basic Description",
                        "Technical Analysis",
                        "Objects & People",
                        "Scene & Context",
                        "Artistic Analysis",
                        "SEO Keywords",
                        "JSON Metadata"
                    ],
                    value="Basic Description",
                    label="Metadata Type"
                )

                generate_btn = gr.Button(
                    "Generate Metadata",
                    variant="primary",
                    size="lg"
                )

            # Right column: output.
            with gr.Column(scale=1):
                output_text = gr.Textbox(
                    label="Generated Metadata",
                    lines=20,
                    max_lines=25,
                    elem_classes=["output-text"]
                )

        # Explicit generation on button click, using the selected type.
        generate_btn.click(
            fn=generate_metadata,
            inputs=[image_input, metadata_type],
            outputs=output_text,
            show_progress=True
        )

        # Auto-generate a basic description as soon as an image is
        # uploaded (cleared image -> empty output).
        image_input.change(
            fn=lambda img: generate_metadata(img, "Basic Description") if img else "",
            inputs=[image_input],
            outputs=output_text,
            show_progress=True
        )

        # Page footer.
        gr.HTML("""
        <div style="text-align: center; padding: 20px; margin-top: 30px; border-top: 1px solid #eee;">
            <p style="color: #666;">
                This Space uses Qwen2.5-VL for intelligent image analysis and metadata generation.
                <br>Perfect for content management, SEO optimization, and accessibility improvements.
            </p>
            <p style="color: #888; font-size: 0.9em; margin-top: 10px;">
                <strong>Note:</strong> First generation may take 1-2 minutes while the model loads. Subsequent generations will be much faster.
            </p>
        </div>
        """)

    return interface
|
|
|
|
|
def initialize_app():
    """Build and return the Gradio interface, logging environment info.

    The heavy model itself is NOT loaded here — load_model() defers that
    until the first generation request.
    """
    print("Starting Image Metadata Generator...")
    print("Model will be loaded on first use to save resources.")

    # Surface the runtime environment in the startup logs so deployments
    # show at a glance whether GPU acceleration is in play.
    print(f"PyTorch version: {torch.__version__}")
    print(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")

    return create_interface()
|
|
|
|
|
if __name__ == "__main__":
    app = initialize_app()
    # Bind on all interfaces at port 7860 (the Hugging Face Spaces /
    # Gradio default); share=False disables the public gradio.live tunnel.
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
|
|
|