"""Gradio app: object detection with per-object captions, plus text-guided image editing.

Two tabs are exposed:

* "Detect Objects" runs YOLOS on the uploaded image, draws labelled boxes,
  captions each detected region with BLIP and renders the result as HTML.
* "Edit Image" runs InstructPix2Pix on the uploaded image with a text prompt.
"""
import json
import traceback
import zlib
from collections import defaultdict
from html import escape
from urllib.parse import quote_plus

import gradio as gr
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from PIL import Image, ImageDraw, ImageFont
from transformers import (
    BlipForConditionalGeneration,
    BlipProcessor,
    YolosForObjectDetection,
    YolosImageProcessor,
)

# Global models, created lazily by load_models().
pipe = None
detector = None
detector_processor = None
captioner = None
caption_processor = None

# Dynamic category storage: category name -> colour string, filled in as
# detections are produced.
DETECTED_CATEGORIES = {}

_DETECTION_THRESHOLD = 0.3
_MAX_EDIT_SIZE = 512
_LABEL_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"


def _device():
    """Return the torch device name the models should run on."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def generate_color(text):
    """Return an ``hsl(...)`` colour string derived deterministically from *text*.

    The built-in ``hash()`` is salted per process for strings, so it would give
    a different colour for the same label on every restart; CRC32 is stable.
    """
    hue = zlib.crc32(text.encode("utf-8")) % 360
    return f"hsl({hue}, 70%, 55%)"


def load_models():
    """Load the editor, detector and captioner models once and cache them globally."""
    global pipe, detector, detector_processor, captioner, caption_processor

    device = _device()

    if pipe is None:
        print("Loading image editor...")
        pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            "timbrooks/instruct-pix2pix",
            # Half precision is only used on GPU; fall back to float32 on CPU.
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            safety_checker=None,
        )
        pipe.to(device)

    if detector is None:
        print("Loading object detector...")
        detector_processor = YolosImageProcessor.from_pretrained('hustvl/yolos-tiny')
        detector = YolosForObjectDetection.from_pretrained('hustvl/yolos-tiny')
        detector.to(device)

    if captioner is None:
        print("Loading image captioner...")
        caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        captioner = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        captioner.to(device)

    print("All models loaded!")


def _load_label_font():
    """Return the bold DejaVu font used for box labels, or PIL's default font."""
    try:
        return ImageFont.truetype(_LABEL_FONT_PATH, 16)
    except OSError:
        # Font file is not installed on this machine.
        return ImageFont.load_default()


def _clamp_box(box, width, height):
    """Clamp an ``[x0, y0, x1, y1]`` box to the bounds of a width x height image."""
    x0, y0, x1, y1 = box
    return [
        min(max(x0, 0), width),
        min(max(y0, 0), height),
        min(max(x1, 0), width),
        min(max(y1, 0), height),
    ]


def detect_objects(image):
    """Detect objects in *image* and describe each one.

    Args:
        image: PIL image from the upload component, or ``None``.

    Returns:
        A 3-tuple ``(annotated_image, html, json_text)``: a copy of the image
        with labelled boxes drawn on it, an HTML summary of the detections and
        the detections serialised as JSON. On failure the unmodified input
        image, an HTML error message and ``"{}"`` are returned.
    """
    if image is None:
        return None, "<p style='color: red;'>Error: no image provided</p>", "{}"

    try:
        load_models()

        # Keep an untouched RGB copy for cropping; draw on a separate copy so
        # the caller's image is not mutated and captions never see the overlay.
        source = image.convert("RGB")
        annotated = source.copy()

        inputs = detector_processor(images=source, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        with torch.no_grad():
            outputs = detector(**inputs)

        # post-processing expects (height, width); PIL's size is (width, height).
        target_sizes = torch.tensor([source.size[::-1]])
        results = detector_processor.post_process_object_detection(
            outputs, threshold=_DETECTION_THRESHOLD, target_sizes=target_sizes
        )[0]

        draw = ImageDraw.Draw(annotated)
        font = _load_label_font()
        img_w, img_h = source.size

        detections = []
        for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
            box = _clamp_box([round(v, 2) for v in box.tolist()], img_w, img_h)
            if box[2] - box[0] < 1 or box[3] - box[1] < 1:
                # Nothing left inside the image after clamping.
                continue

            label_name = detector.config.id2label[label.item()]
            confidence = round(score.item(), 3)

            # The label itself doubles as the category.
            category = label_name
            color = generate_color(label_name)
            DETECTED_CATEGORIES.setdefault(category, color)

            draw.rectangle(box, outline=color, width=3)

            # Label goes just above the box, but never above the image edge.
            text = f"{label_name} {confidence:.0%}"
            text_pos = (box[0], max(box[1] - 20, 0))
            bbox = draw.textbbox(text_pos, text, font=font)
            draw.rectangle([bbox[0] - 2, bbox[1] - 2, bbox[2] + 2, bbox[3] + 2], fill=color)
            draw.text(text_pos, text, fill='white', font=font)

            obj_info = get_detailed_info(source.crop(box), label_name)

            detections.append({
                'label': label_name,
                'category': category,
                'confidence': f"{confidence:.1%}",
                'bbox': box,
                'color': color,
                'details': obj_info
            })

        html_output = create_detection_html(detections)
        return annotated, html_output, json.dumps(detections, indent=2)

    except Exception as e:
        # UI boundary: report the failure in the page instead of crashing.
        print(f"Detection error: {e}")
        traceback.print_exc()
        return image, f"<p style='color: red;'>Error: {escape(str(e))}</p>", "{}"


def get_detailed_info(obj_image, label):
    """Caption a cropped object and build a web-search link for it.

    Args:
        obj_image: PIL image containing just the detected object.
        label: Detector label for the object.

    Returns:
        Dict with ``'description'`` (the generated caption, or ``"A <label>"``
        if captioning fails) and ``'search_url'`` (a Google search URL).
    """
    try:
        inputs = caption_processor(obj_image, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        out = captioner.generate(**inputs, max_length=50)
        caption = caption_processor.decode(out[0], skip_special_tokens=True)

        # quote_plus also encodes '&', '#', '?' etc., not only spaces.
        search_url = f"https://www.google.com/search?q={quote_plus(f'{label} {caption}')}"
        return {
            'description': caption,
            'search_url': search_url
        }
    except Exception as e:
        # Best effort: a failed caption must not abort the whole detection run.
        print(f"Caption error for {label}: {e}")
        search_url = f"https://www.google.com/search?q={quote_plus(label)}"
        return {
            'description': f"A {label}",
            'search_url': search_url
        }


def create_detection_html(detections):
    """Render detections as HTML, grouped by category, with a search link each.

    Args:
        detections: List of detection dicts as built by ``detect_objects``
            (keys ``label``, ``category``, ``confidence``, ``color``,
            ``details``).

    Returns:
        An HTML fragment; a short placeholder message when the list is empty.
    """
    # NOTE(review): tag structure and inline styles are presentational choices;
    # confirm against the intended design. Visible text is unchanged.
    if not detections:
        return "<p style='text-align: center; color: #666;'>No objects detected</p>"

    by_category = defaultdict(list)
    for det in detections:
        by_category[det['category']].append(det)

    parts = ["<div style='max-height: 600px; overflow-y: auto; padding: 4px;'>"]
    for category, items in by_category.items():
        color = generate_color(category)
        parts.append(
            f"<h3 style='color: {color}; border-bottom: 2px solid {color}; "
            f"padding-bottom: 4px;'>{escape(category)}s ({len(items)})</h3>"
        )
        for det in items:
            # Labels and captions are model output; escape before embedding.
            label = escape(det['label'])
            parts.append(
                f"<div style='border-left: 4px solid {escape(det['color'])}; "
                f"margin: 8px 0; padding: 8px 12px;'>"
                f"<div style='font-weight: bold;'>{label}</div>"
                f"<div style='font-size: 0.9em;'>"
                f"{escape(det['category'])} Confidence: {escape(det['confidence'])}</div>"
                f"<div style='margin: 4px 0;'>{escape(det['details']['description'])}</div>"
                f"<a href=\"{escape(det['details']['search_url'])}\" target=\"_blank\" "
                f"rel=\"noopener noreferrer\">🔍 Learn more about this {label}</a>"
                f"</div>"
            )
    parts.append("</div>")
    return "".join(parts)


def edit_image(input_image, edit_prompt, num_steps, guidance_scale, image_guidance_scale):
    """Edit *input_image* according to the text instruction *edit_prompt*.

    Args:
        input_image: PIL image from the upload component, or ``None``.
        edit_prompt: Text instruction describing the edit.
        num_steps: Number of inference steps.
        guidance_scale: Text guidance strength.
        image_guidance_scale: Image guidance strength.

    Returns:
        ``(result_image, status_message)``; the image is ``None`` when the
        inputs are missing or the edit fails.
    """
    if input_image is None or not (edit_prompt or "").strip():
        return None, "❌ Provide image and prompt!"

    try:
        load_models()

        # Uploads may be RGBA / greyscale; normalise to RGB before editing.
        input_image = input_image.convert("RGB")

        # Downscale so the longest side is at most _MAX_EDIT_SIZE pixels.
        longest = max(input_image.size)
        if longest > _MAX_EDIT_SIZE:
            ratio = _MAX_EDIT_SIZE / longest
            new_size = tuple(max(int(dim * ratio), 1) for dim in input_image.size)
            input_image = input_image.resize(new_size, Image.Resampling.LANCZOS)

        # Snap both sides down to a multiple of 8, never below 8, so very
        # small or very thin images do not end up with a zero dimension.
        width = max((input_image.width // 8) * 8, 8)
        height = max((input_image.height // 8) * 8, 8)
        input_image = input_image.resize((width, height))

        result = pipe(
            edit_prompt,
            image=input_image,
            num_inference_steps=int(num_steps),
            guidance_scale=guidance_scale,
            image_guidance_scale=image_guidance_scale,
        ).images[0]

        return result, "✅ Done!"
    except Exception as e:
        # UI boundary: surface the failure in the status box.
        return None, f"❌ Error: {str(e)}"


# Build interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 AI Image Editor & Object Detector")

    with gr.Tabs():
        with gr.Tab("🔍 Detect Objects"):
            gr.Markdown("Upload an image to detect and identify objects with detailed information")

            with gr.Row():
                with gr.Column():
                    detect_input = gr.Image(label="Upload Image", type="pil")
                    detect_btn = gr.Button("🔍 Detect Objects", variant="primary", size="lg")

                with gr.Column():
                    detect_output = gr.Image(label="Detected Objects")
                    detection_info = gr.HTML(label="Object Details (Click to learn more)")
                    detection_json = gr.JSON(label="Detection Data", visible=False)

            detect_btn.click(
                fn=detect_objects,
                inputs=[detect_input],
                outputs=[detect_output, detection_info, detection_json]
            )

        with gr.Tab("✏️ Edit Image"):
            gr.Markdown("Edit images with text instructions")

            with gr.Row():
                with gr.Column():
                    edit_input = gr.Image(label="Upload Image", type="pil")
                    edit_prompt = gr.Textbox(
                        label="Instructions",
                        placeholder="make it a painting, add snow, turn day into night...",
                        lines=2
                    )

                    with gr.Accordion("Settings", open=False):
                        num_steps = gr.Slider(10, 50, value=20, step=5, label="Steps")
                        guidance_scale = gr.Slider(1, 10, value=7.5, step=0.5, label="Text Guidance")
                        image_guidance_scale = gr.Slider(1, 2, value=1.5, step=0.1, label="Image Guidance")

                    edit_btn = gr.Button("✨ Edit", variant="primary")

                with gr.Column():
                    edit_output = gr.Image(label="Result")
                    edit_status = gr.Textbox(label="Status", interactive=False)

            edit_btn.click(
                fn=edit_image,
                inputs=[edit_input, edit_prompt, num_steps, guidance_scale, image_guidance_scale],
                outputs=[edit_output, edit_status]
            )

    gr.Markdown("""
    ### 🎯 Features:
    - **Object Detection**: Identifies objects with bounding boxes and confidence scores
    - **Categories**: Color-coded by type (vehicles, animals, people, etc.)
    - **Detailed Info**: AI-generated descriptions for each object
    - **Clickable Links**: Click any object to learn more about it
    - **Image Editing**: Transform images with simple text instructions
    """)

demo.launch()