# minor-project / app.py
# (Hugging Face Spaces page header preserved as a comment:
#  "PiyushGPT — Update app.py — bc7fe11 verified")
import os
import torch
from PIL import Image, ImageDraw
from transformers import OwlViTProcessor, OwlViTForObjectDetection
import gradio as gr
# Set environment variables
# NOTE(review): presumably disables TorchDynamo to avoid graph-compile issues
# in the hosted environment — confirm motivation before removing.
os.environ["TORCHDYNAMO_DISABLE"] = "1"
# Global variables for model and processor (populated lazily by load_model)
model = None
processor = None
# Load Model and Processor
def load_model():
    """Load the OwlViT model and processor, caching them in module globals.

    Prefers a local checkpoint directory (./owlvit-base-patch32) when it
    exists; otherwise downloads from the Hugging Face Hub. Repeated calls
    are cheap: already-loaded instances are returned immediately.

    Returns:
        tuple: (model, processor) with the model in eval mode on the
        selected device (CUDA when available, otherwise CPU).

    Raises:
        RuntimeError: if loading fails for any reason; the original
        exception is chained for debugging.
    """
    global model, processor
    if model is not None and processor is not None:
        return model, processor

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_name = "google/owlvit-base-patch32"
    local_model_path = "./owlvit-base-patch32"
    try:
        # os.path.isdir already implies existence; pick the checkpoint source.
        if os.path.isdir(local_model_path):
            print(f"Loading model from local directory: {local_model_path}")
            source = local_model_path
        else:
            print(f"Loading model from Hugging Face Hub: {model_name}")
            source = model_name
        processor = OwlViTProcessor.from_pretrained(source)
        model = OwlViTForObjectDetection.from_pretrained(source)
        model.eval()
        model.to(device)
        print("Model loaded successfully!")
        return model, processor
    except Exception as e:
        # Chain the original exception so the root cause is not lost.
        raise RuntimeError(f"Failed to load model: {str(e)}") from e
# Draw Bounding Boxes Function
def draw_boxes(image, results, queries):
    """Annotate *image* in place with red boxes and "query: score" captions.

    Args:
        image: PIL.Image to draw on (mutated and also returned).
        results: post-processed detection output; results[0] holds the
            "boxes", "scores" and "labels" tensors.
        queries: list of query strings; each label value indexes into it.

    Returns:
        The same PIL.Image instance with annotations drawn.
    """
    canvas = ImageDraw.Draw(image)
    detections = results[0]
    for box, score, label in zip(
        detections["boxes"], detections["scores"], detections["labels"]
    ):
        left, top, right, bottom = box.tolist()
        canvas.rectangle([left, top, right, bottom], outline="red", width=3)
        caption = f"{queries[label]}: {score:.2f}"
        # Place the caption just above the box's top-left corner.
        canvas.text((left, top - 15), caption, fill="red")
    return image
# Prediction Function
def detect_objects(image, text_query, threshold):
    """Run open-vocabulary object detection for comma-separated queries.

    Args:
        image: PIL.Image or numpy array from Gradio (None when no upload).
        text_query: comma-separated object descriptions, e.g. "cat, dog".
        threshold: confidence threshold passed to post-processing (0.0-1.0).

    Returns:
        Annotated PIL.Image on success; the unmodified image when there are
        no usable queries or an error occurs (best-effort UI behavior);
        None when no image was provided.
    """
    global model, processor
    if image is None:
        return None
    try:
        # Lazy-load the model on first use.
        if model is None or processor is None:
            model, processor = load_model()

        # Normalize the input to an RGB PIL image (Gradio may supply a
        # numpy array depending on component configuration).
        if not isinstance(image, Image.Image):
            image = Image.fromarray(image).convert("RGB")
        else:
            image = image.convert("RGB")

        # Split "a, b, c" into individual query strings, dropping blanks.
        text_queries = [q.strip() for q in text_query.split(",") if q.strip()]
        if not text_queries:
            return image

        inputs = processor(text=text_queries, images=image, return_tensors="pt")
        # Run on whatever device the model currently lives on.
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        # Post-processor expects (height, width); PIL's size is (width, height).
        # torch.tensor(...) is the recommended factory (torch.Tensor(...) is a
        # legacy constructor and is discouraged).
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs=outputs,
            threshold=threshold,
            target_sizes=target_sizes,
        )
        return draw_boxes(image.copy(), results, text_queries)
    except Exception as e:
        # Best-effort: log and return the unannotated image so the UI never
        # hard-fails on a single bad request.
        print(f"Error during detection: {str(e)}")
        return image
# Gradio Interface: two-column layout — inputs on the left, result on the
# right — with example queries and two triggers (button click, textbox Enter).
with gr.Blocks(title="Query based object detection") as app:
    gr.Markdown(
        """
        Upload an image and describe what you want to detect. You can specify multiple objects separated by commas.
        **Example queries:**
        - `a dog on couch sofa`
        - `person, car, bicycle`
        - `red apple, green apple`
        """
    )

    with gr.Row():
        with gr.Column():
            # Left column: image upload, query text, threshold, run button.
            img_in = gr.Image(label="Upload Image", type="pil", height=400)
            query_box = gr.Textbox(
                label="Text Query",
                placeholder="e.g., a dog on couch sofa",
                value="a dog on couch sofa",
            )
            conf_slider = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                value=0.1,
                step=0.05,
                info="Lower values detect more objects but may include false positives",
            )
            run_button = gr.Button("Detect Objects", variant="primary")
        with gr.Column():
            # Right column: annotated output image.
            result_view = gr.Image(label="Detected Objects", type="pil", height=400)

    # Example queries
    gr.Markdown("### Examples")
    gr.Examples(
        examples=[
            ["a dog on couch sofa", 0.1],
            ["person, car", 0.1],
            ["cat, dog", 0.1],
        ],
        inputs=[query_box, conf_slider],
        label="Try these queries",
    )

    # Wire both the button click and the Enter key to the same handler.
    run_button.click(
        fn=detect_objects,
        inputs=[img_in, query_box, conf_slider],
        outputs=result_view,
    )
    query_box.submit(
        fn=detect_objects,
        inputs=[img_in, query_box, conf_slider],
        outputs=result_view,
    )

app.launch()