import gradio as gr from diffusers import StableDiffusionInstructPix2PixPipeline from transformers import YolosImageProcessor, YolosForObjectDetection, BlipProcessor, BlipForConditionalGeneration from PIL import Image, ImageDraw, ImageFont import torch import json # Global models pipe = None detector = None detector_processor = None captioner = None caption_processor = None # Dynamic color generator def generate_color(text): """Generate consistent color from text using hash""" hash_val = hash(text) % 360 return f"hsl({hash_val}, 70%, 55%)" # Dynamic category storage DETECTED_CATEGORIES = {} def load_models(): """Load all models""" global pipe, detector, detector_processor, captioner, caption_processor if pipe is None: print("Loading image editor...") pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained( "timbrooks/instruct-pix2pix", torch_dtype=torch.float16, safety_checker=None ) pipe.to("cuda" if torch.cuda.is_available() else "cpu") if detector is None: print("Loading object detector...") detector_processor = YolosImageProcessor.from_pretrained('hustvl/yolos-tiny') detector = YolosForObjectDetection.from_pretrained('hustvl/yolos-tiny') detector.to("cuda" if torch.cuda.is_available() else "cpu") if captioner is None: print("Loading image captioner...") caption_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base") captioner = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base") captioner.to("cuda" if torch.cuda.is_available() else "cpu") print("All models loaded!") def detect_objects(image): """Detect objects in image with detailed info""" load_models() try: # Detect objects inputs = detector_processor(images=image, return_tensors="pt") if torch.cuda.is_available(): inputs = {k: v.to("cuda") for k, v in inputs.items()} outputs = detector(**inputs) target_sizes = torch.tensor([image.size[::-1]]) results = detector_processor.post_process_object_detection(outputs, threshold=0.3, target_sizes=target_sizes)[0] # Draw on image draw = ImageDraw.Draw(image) try: font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16) except: font = ImageFont.load_default() detections = [] for score, label, box in zip(results["scores"], results["labels"], results["boxes"]): box = [round(i, 2) for i in box.tolist()] label_name = detector.config.id2label[label.item()] confidence = round(score.item(), 3) # Auto-generate category and color category = label_name # Use the label itself as category color = generate_color(label_name) # Store in dynamic dict if category not in DETECTED_CATEGORIES: DETECTED_CATEGORIES[category] = color # Draw box draw.rectangle(box, outline=color, width=3) # Draw label background text = f"{label_name} {confidence:.0%}" bbox = draw.textbbox((box[0], box[1]-20), text, font=font) draw.rectangle([bbox[0]-2, bbox[1]-2, bbox[2]+2, bbox[3]+2], fill=color) draw.text((box[0], box[1]-20), text, fill='white', font=font) # Get specific info about this object obj_image = image.crop(box) obj_info = get_detailed_info(obj_image, label_name) detections.append({ 'label': label_name, 'category': category, 'confidence': f"{confidence:.1%}", 'bbox': box, 'color': color, 'details': obj_info }) # Create HTML output with clickable objects html_output = create_detection_html(detections) return image, html_output, json.dumps(detections, indent=2) except Exception as e: print(f"Detection error: {e}") import traceback traceback.print_exc() return image, f"
Error: {str(e)}
", "{}" def get_detailed_info(obj_image, label): """Get detailed description of the detected object""" try: # Generate caption for the object inputs = caption_processor(obj_image, return_tensors="pt") if torch.cuda.is_available(): inputs = {k: v.to("cuda") for k, v in inputs.items()} out = captioner.generate(**inputs, max_length=50) caption = caption_processor.decode(out[0], skip_special_tokens=True) # Create search URL search_query = f"{label} {caption}".replace(' ', '+') search_url = f"https://www.google.com/search?q={search_query}" return { 'description': caption, 'search_url': search_url } except: search_url = f"https://www.google.com/search?q={label.replace(' ', '+')}" return { 'description': f"A {label}", 'search_url': search_url } def create_detection_html(detections): """Create interactive HTML with clickable detections""" if not detections: return "No objects detected
" html = """