|
|
import gradio as gr |
|
|
from diffusers import StableDiffusionInstructPix2PixPipeline |
|
|
from transformers import YolosImageProcessor, YolosForObjectDetection, BlipProcessor, BlipForConditionalGeneration |
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
import torch |
|
|
import json |
|
|
|
|
|
|
|
|
# Lazily-initialised model handles. They start out empty and are populated
# by load_models() the first time a model is actually needed.
pipe = detector = detector_processor = captioner = caption_processor = None
|
|
|
|
|
|
|
|
def generate_color(text):
    """Return a deterministic CSS/PIL ``hsl(...)`` color string for *text*.

    The hue is derived from a CRC-32 checksum of the UTF-8 encoded text, so
    the same label maps to the same color in every process. (The built-in
    ``hash()`` is salted per interpreter run for strings, which made the
    colors change on every restart.)

    Args:
        text: Label to colorize; non-string values are converted with ``str``.

    Returns:
        A string of the form ``"hsl(<hue>, 70%, 55%)"`` with hue in [0, 360).
    """
    import zlib  # local import: keeps this block independent of file-level imports

    hue = zlib.crc32(str(text).encode("utf-8")) % 360
    return f"hsl({hue}, 70%, 55%)"
|
|
|
|
|
|
|
|
# Registry of category name -> color string, filled in as objects are detected.
DETECTED_CATEGORIES = dict()
|
|
|
|
|
def load_models():
    """Load the editing, detection and captioning models on first use.

    Each model is loaded only if its module-level handle is still ``None``,
    so repeated calls are cheap. Models are moved to CUDA when available and
    to the CPU otherwise.

    The editing pipeline is loaded in float16 only on CUDA; on the CPU
    fallback it is loaded in float32, because half precision is intended for
    GPU inference. A handle is published to its global only after the model
    has been moved to the target device, so a failed load is retried on the
    next call instead of leaving a half-initialised model behind.
    """
    global pipe, detector, detector_processor, captioner, caption_processor

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"

    if pipe is None:
        print("Loading image editor...")
        editor = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            "timbrooks/instruct-pix2pix",
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            safety_checker=None,
        )
        editor.to(device)
        pipe = editor

    if detector is None:
        print("Loading object detector...")
        processor = YolosImageProcessor.from_pretrained('hustvl/yolos-tiny')
        model = YolosForObjectDetection.from_pretrained('hustvl/yolos-tiny')
        model.to(device)
        detector_processor, detector = processor, model

    if captioner is None:
        print("Loading image captioner...")
        processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
        model.to(device)
        caption_processor, captioner = processor, model

    print("All models loaded!")
|
|
|
|
|
def detect_objects(image):
    """Detect objects in *image* and describe each one.

    Bounding boxes and labels are drawn on a copy of the image, so the
    caller's image is left untouched and the crops passed to the captioner
    are free of box outlines and label text.

    Args:
        image: PIL image to analyse, or ``None``.

    Returns:
        A 3-tuple ``(annotated_image, html, json_text)``: the annotated copy,
        an HTML summary from ``create_detection_html`` and the detections as
        a JSON string. On failure the original image, an error paragraph and
        ``"{}"`` are returned instead.
    """
    from html import escape  # local import: only needed for the error message

    if image is None:
        return None, "<p>Error: no image provided</p>", "{}"

    load_models()

    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        inputs = detector_processor(images=image, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Inference only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            outputs = detector(**inputs)

        # PIL reports (width, height); the post-processor wants (height, width).
        target_sizes = torch.tensor([image.size[::-1]])
        results = detector_processor.post_process_object_detection(
            outputs, threshold=0.3, target_sizes=target_sizes
        )[0]

        annotated = image.copy()
        draw = ImageDraw.Draw(annotated)
        try:
            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
        except OSError:
            # Font file missing or unreadable on this machine.
            font = ImageFont.load_default()

        img_w, img_h = image.size
        detections = []

        for score, label, raw_box in zip(results["scores"], results["labels"], results["boxes"]):
            x0, y0, x1, y1 = raw_box.tolist()
            # Clamp to the image so drawing and cropping never leave the canvas.
            box = [
                round(min(max(x0, 0), img_w), 2),
                round(min(max(y0, 0), img_h), 2),
                round(min(max(x1, 0), img_w), 2),
                round(min(max(y1, 0), img_h), 2),
            ]
            if box[2] <= box[0] or box[3] <= box[1]:
                # Degenerate box after clamping: nothing to draw or crop.
                continue

            label_name = detector.config.id2label[label.item()]
            confidence = round(score.item(), 3)

            category = label_name
            color = generate_color(label_name)

            if category not in DETECTED_CATEGORIES:
                DETECTED_CATEGORIES[category] = color

            draw.rectangle(box, outline=color, width=3)

            text = f"{label_name} {confidence:.0%}"
            # Keep the label on-canvas for boxes that touch the top edge.
            text_pos = (box[0], max(box[1] - 20, 0))
            bbox = draw.textbbox(text_pos, text, font=font)
            draw.rectangle([bbox[0]-2, bbox[1]-2, bbox[2]+2, bbox[3]+2], fill=color)
            draw.text(text_pos, text, fill='white', font=font)

            # Crop from the pristine image, not the annotated copy.
            obj_image = image.crop(box)
            obj_info = get_detailed_info(obj_image, label_name)

            detections.append({
                'label': label_name,
                'category': category,
                'confidence': f"{confidence:.1%}",
                'bbox': box,
                'color': color,
                'details': obj_info,
            })

        html_output = create_detection_html(detections)

        return annotated, html_output, json.dumps(detections, indent=2)

    except Exception as e:
        print(f"Detection error: {e}")
        import traceback
        traceback.print_exc()
        return image, f"<p>Error: {escape(str(e))}</p>", "{}"
|
|
|
|
|
def get_detailed_info(obj_image, label):
    """Caption a cropped object and build a web-search link for it.

    Captioning is best-effort: if anything goes wrong (model not loaded,
    unusable crop, inference error) the function falls back to a generic
    description built from *label* instead of raising.

    Args:
        obj_image: Cropped image of the detected object.
        label: Detector label for the object, e.g. ``"traffic light"``.

    Returns:
        A dict with ``'description'`` (caption text) and ``'search_url'``
        (a Google search URL with a properly URL-encoded query).
    """
    from urllib.parse import quote_plus  # local import: keeps the block self-contained

    search_base = "https://www.google.com/search?q="

    try:
        inputs = caption_processor(obj_image, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.to("cuda") for k, v in inputs.items()}

        # Inference only: no gradients needed.
        with torch.no_grad():
            out = captioner.generate(**inputs, max_length=50)
        caption = caption_processor.decode(out[0], skip_special_tokens=True)
    except Exception as e:
        # Deliberate best-effort fallback, but say why it happened.
        print(f"Captioning error for {label!r}: {e}")
        return {
            'description': f"A {label}",
            'search_url': search_base + quote_plus(label),
        }

    return {
        'description': caption,
        # quote_plus encodes spaces as '+' and escapes '&', '#', '?' etc.
        'search_url': search_base + quote_plus(f"{label} {caption}"),
    }
|
|
|
|
|
def create_detection_html(detections):
    """Render detections as an HTML list grouped by category.

    Every value taken from a detection (label, category, caption, URL,
    color) is HTML-escaped before it is interpolated, and the URL passed to
    the inline ``window.open`` handler is encoded as a JS string literal, so
    captions containing quotes or markup cannot break the page.

    Args:
        detections: List of dicts with the keys ``'label'``, ``'category'``,
            ``'confidence'``, ``'color'`` and ``'details'`` (itself a dict
            with ``'description'`` and ``'search_url'``).

    Returns:
        An HTML string; ``"<p>No objects detected</p>"`` for an empty list.
    """
    from html import escape  # local import: keeps the block self-contained

    if not detections:
        return "<p>No objects detected</p>"

    def tint(color):
        """Return a faint, translucent variant of *color* for backgrounds."""
        if color.startswith("hsl(") and color.endswith(")"):
            # Appending a hex alpha suffix to hsl(...) is invalid CSS.
            return f"hsla({color[4:-1]}, 0.08)"
        if color.startswith("#") and len(color) == 7:
            return f"{color}15"
        return color

    parts = ["""
    <style>
    .detection-container {font-family: Arial; padding: 10px;}
    .detection-item {margin: 15px 0; padding: 15px; border-radius: 8px; border-left: 5px solid; cursor: pointer; transition: transform 0.2s;}
    .detection-item:hover {transform: translateX(5px); box-shadow: 0 2px 8px rgba(0,0,0,0.1);}
    .object-label {font-size: 18px; font-weight: bold; margin-bottom: 5px;}
    .object-details {font-size: 14px; color: #555; margin: 5px 0;}
    .object-category {display: inline-block; padding: 3px 10px; border-radius: 12px; font-size: 12px; color: white; margin-right: 10px;}
    .search-link {color: #1a73e8; text-decoration: none; font-size: 13px;}
    .search-link:hover {text-decoration: underline;}
    </style>
    <div class="detection-container">
    """]

    # Group detections by category, preserving first-seen order.
    by_category = {}
    for det in detections:
        by_category.setdefault(det['category'], []).append(det)

    for category, items in by_category.items():
        # Detections in one category carry the same color, so the first
        # item's color styles the heading.
        heading_color = escape(str(items[0]['color']))
        parts.append(
            f"<h3 style='color: {heading_color}; text-transform: capitalize;'>"
            f"{escape(str(category))}s ({len(items)})</h3>"
        )

        for det in items:
            color = escape(str(det['color']))
            background = escape(tint(str(det['color'])))
            label = escape(str(det['label']))
            det_category = escape(str(det['category']))
            confidence = escape(str(det['confidence']))
            description = escape(str(det['details']['description']))
            url = str(det['details']['search_url'])
            href = escape(url)
            # json.dumps yields a safe JS string literal; escape() then makes
            # it safe inside the double-quoted HTML attribute.
            open_js = escape(f"window.open({json.dumps(url)}, '_blank')")

            parts.append(f"""
            <div class="detection-item" style="border-left-color: {color}; background: {background};"
                 onclick="{open_js}">
                <div class="object-label" style="color: {color};">{label}</div>
                <div class="object-details">
                    <span class="object-category" style="background: {color};">{det_category}</span>
                    <span>Confidence: {confidence}</span>
                </div>
                <div class="object-details">{description}</div>
                <a href="{href}" target="_blank" class="search-link" onclick="event.stopPropagation();">
                    🔍 Learn more about this {label}
                </a>
            </div>
            """)

    parts.append("</div>")
    return "".join(parts)
|
|
|
|
|
def _pipeline_size(size, max_side=512, multiple=8):
    """Return (width, height) scaled to fit *max_side* and snapped down to *multiple*."""
    width, height = size
    longest = max(width, height)
    # Only ever shrink; images already within the limit keep their scale.
    scale = max_side / longest if longest > max_side else 1.0

    def snap(dim):
        # Never let a dimension collapse to 0 for tiny or very thin images.
        return max(multiple, int(dim * scale) // multiple * multiple)

    return snap(width), snap(height)


def edit_image(input_image, edit_prompt, num_steps, guidance_scale, image_guidance_scale):
    """Edit an image according to a text instruction.

    The image is shrunk so its longest side is at most 512 pixels and both
    dimensions are rounded down to a multiple of 8 (in a single resize)
    before it is handed to the editing pipeline.

    Args:
        input_image: PIL image to edit, or ``None``.
        edit_prompt: Text instruction; ``None``, empty or blank is rejected.
        num_steps: Number of inference steps (cast to ``int``).
        guidance_scale: Text guidance strength passed to the pipeline.
        image_guidance_scale: Image guidance strength passed to the pipeline.

    Returns:
        ``(edited_image, status_message)``; the image is ``None`` when the
        input is missing or the edit fails.
    """
    if input_image is None or not edit_prompt or not edit_prompt.strip():
        return None, "❌ Provide image and prompt!"

    try:
        load_models()

        target = _pipeline_size(input_image.size)
        if input_image.size != target:
            input_image = input_image.resize(target, Image.Resampling.LANCZOS)

        result = pipe(
            edit_prompt,
            image=input_image,
            # UI sliders may deliver floats; the step count must be an int.
            num_inference_steps=int(num_steps),
            guidance_scale=guidance_scale,
            image_guidance_scale=image_guidance_scale,
        ).images[0]

        return result, "✅ Done!"

    except Exception as e:
        return None, f"❌ Error: {str(e)}"
|
|
|
|
|
|
|
|
# Gradio UI: one tab for object detection, one for instruction-based editing.
# NOTE(review): the nesting below was reconstructed from a paste that had lost
# its indentation, and the emoji in the labels were restored from mis-encoded
# text — confirm both against the intended layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🎨 AI Image Editor & Object Detector")

    with gr.Tabs():
        with gr.Tab("🔍 Detect Objects"):
            gr.Markdown("Upload an image to detect and identify objects with detailed information")
            with gr.Row():
                with gr.Column():
                    detect_input = gr.Image(label="Upload Image", type="pil")
                    detect_btn = gr.Button("🔍 Detect Objects", variant="primary", size="lg")

                with gr.Column():
                    detect_output = gr.Image(label="Detected Objects")

            detection_info = gr.HTML(label="Object Details (Click to learn more)")
            # Hidden component that receives the raw detection data.
            detection_json = gr.JSON(label="Detection Data", visible=False)

            detect_btn.click(
                fn=detect_objects,
                inputs=[detect_input],
                outputs=[detect_output, detection_info, detection_json],
            )

        with gr.Tab("✏️ Edit Image"):
            gr.Markdown("Edit images with text instructions")
            with gr.Row():
                with gr.Column():
                    edit_input = gr.Image(label="Upload Image", type="pil")
                    edit_prompt = gr.Textbox(
                        label="Instructions",
                        placeholder="make it a painting, add snow, turn day into night...",
                        lines=2,
                    )
                    with gr.Accordion("Settings", open=False):
                        num_steps = gr.Slider(10, 50, value=20, step=5, label="Steps")
                        guidance_scale = gr.Slider(1, 10, value=7.5, step=0.5, label="Text Guidance")
                        image_guidance_scale = gr.Slider(1, 2, value=1.5, step=0.1, label="Image Guidance")
                    edit_btn = gr.Button("✨ Edit", variant="primary")

                with gr.Column():
                    edit_output = gr.Image(label="Result")
                    edit_status = gr.Textbox(label="Status", interactive=False)

            edit_btn.click(
                fn=edit_image,
                inputs=[edit_input, edit_prompt, num_steps, guidance_scale, image_guidance_scale],
                outputs=[edit_output, edit_status],
            )

    gr.Markdown("""
    ### 🎯 Features:
    - **Object Detection**: Identifies objects with bounding boxes and confidence scores
    - **Categories**: Color-coded by type (vehicles, animals, people, etc.)
    - **Detailed Info**: AI-generated descriptions for each object
    - **Clickable Links**: Click any object to learn more about it
    - **Image Editing**: Transform images with simple text instructions
    """)

demo.launch()