PiyushGPT commited on
Commit
65d162f
·
verified ·
1 Parent(s): 0fc3447

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +197 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import torch
from PIL import Image, ImageDraw
from transformers import OwlViTProcessor, OwlViTForObjectDetection
import gradio as gr

# Disable TorchDynamo/torch.compile tracing for this process — presumably to
# avoid compilation issues on the inference host. TODO confirm this is still
# needed with the pinned torch version.
os.environ["TORCHDYNAMO_DISABLE"] = "1"

# Lazily-initialized module-level singletons; populated by load_model() on
# first use so the app starts quickly and downloads weights only on demand.
model = None      # OwlViTForObjectDetection instance, or None until loaded
processor = None  # OwlViTProcessor instance, or None until loaded
14
# Load Model and Processor
def load_model():
    """Load the OwlViT model and processor, caching them in module globals.

    Prefers a local checkout at ./owlvit-base-patch32 (offline use); falls
    back to downloading google/owlvit-base-patch32 from the Hugging Face Hub.
    Subsequent calls return the cached instances without reloading.

    Returns:
        Tuple of (model, processor); the model is in eval mode and moved to
        GPU when CUDA is available, otherwise CPU.

    Raises:
        RuntimeError: if neither the local directory nor the Hub load succeeds.
    """
    global model, processor

    # Already loaded — reuse the cached instances.
    if model is not None and processor is not None:
        return model, processor

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_name = "google/owlvit-base-patch32"
    local_model_path = "./owlvit-base-patch32"

    try:
        # os.path.isdir already implies existence; no separate exists() check.
        if os.path.isdir(local_model_path):
            print(f"Loading model from local directory: {local_model_path}")
            processor = OwlViTProcessor.from_pretrained(local_model_path)
            model = OwlViTForObjectDetection.from_pretrained(local_model_path)
        else:
            print(f"Loading model from Hugging Face Hub: {model_name}")
            processor = OwlViTProcessor.from_pretrained(model_name)
            model = OwlViTForObjectDetection.from_pretrained(model_name)

        model.eval()  # inference only — disable dropout/batch-norm updates
        model.to(device)
        print("Model loaded successfully!")
        return model, processor

    except Exception as e:
        # Chain the original cause so the full traceback explains the failure.
        raise RuntimeError(f"Failed to load model: {e}") from e
45
+
46
# Draw Bounding Boxes Function
def draw_boxes(image, results, queries):
    """Annotate *image* in place with detection boxes and score captions.

    Args:
        image: PIL.Image to draw on (mutated in place).
        results: output of ``processor.post_process_object_detection`` — a
            list whose first element is a dict with "boxes" (xyxy tensors),
            "scores" and "labels".
        queries: list of query strings; each label indexes into this list.

    Returns:
        The same PIL.Image, annotated.
    """
    draw = ImageDraw.Draw(image)
    detections = results[0]

    for box, score, label in zip(
        detections["boxes"], detections["scores"], detections["labels"]
    ):
        x1, y1, x2, y2 = box.tolist()
        draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
        # Explicit int() — label is a tensor; make the list indexing obvious.
        text = f"{queries[int(label)]}: {score:.2f}"
        # Clamp the caption y so it stays on-canvas when the box touches the
        # top edge (otherwise it would be drawn at a negative y and clipped).
        draw.text((x1, max(0, y1 - 15)), text, fill="red")

    return image
62
+
63
# Prediction Function
def detect_objects(image, text_query, threshold):
    """Run zero-shot object detection on *image* for the given text queries.

    Args:
        image: PIL Image or numpy array from the Gradio widget (may be None).
        text_query: Comma-separated text queries (e.g., "dog, cat, person").
        threshold: Detection confidence threshold for post-processing.

    Returns:
        An annotated copy of the image with bounding boxes; the unannotated
        image when no queries are given or detection fails; None when no
        image was supplied.
    """
    global model, processor

    if image is None:
        return None

    try:
        # Lazy-load the model on the first request.
        if model is None or processor is None:
            model, processor = load_model()

        # Normalize the input to an RGB PIL image regardless of widget type.
        if isinstance(image, Image.Image):
            image = image.convert("RGB")
        else:
            image = Image.fromarray(image).convert("RGB")

        # Split "a, b, c" into individual queries, dropping empty fragments.
        text_queries = [q.strip() for q in text_query.split(",") if q.strip()]
        if not text_queries:
            return image

        inputs = processor(text=text_queries, images=image, return_tensors="pt")

        # Run on whatever device the model parameters live on.
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)

        # Post-processing expects (height, width); PIL .size is (width, height).
        # torch.tensor is the modern factory; torch.Tensor([...]) is legacy.
        target_sizes = torch.tensor([image.size[::-1]])
        results = processor.post_process_object_detection(
            outputs=outputs,
            threshold=threshold,
            target_sizes=target_sizes,
        )

        # Draw on a copy so the normalized input image stays pristine.
        return draw_boxes(image.copy(), results, text_queries)

    except Exception as e:
        # Best-effort UI handler: log the error and return the unannotated
        # image rather than crashing the Gradio app.
        print(f"Error during detection: {e}")
        return image
125
+
126
# Gradio Interface — declarative UI layout plus event wiring. Component
# creation order inside the context managers determines on-screen layout.
with gr.Blocks(title="OwlViT Object Detection", theme=gr.themes.Soft()) as demo:
    # Page header and usage instructions.
    gr.Markdown(
        """
        # 🦉 OwlViT Object Detection

        Upload an image and describe what you want to detect. You can specify multiple objects separated by commas.

        **Example queries:**
        - `a dog on couch sofa`
        - `person, car, bicycle`
        - `red apple, green apple`
        """
    )

    with gr.Row():
        # Left column: inputs.
        with gr.Column():
            image_input = gr.Image(
                label="Upload Image",
                type="pil",  # hand detect_objects a PIL image directly
                height=400
            )
            text_input = gr.Textbox(
                label="Text Query",
                placeholder="e.g., a dog on couch sofa",
                value="a dog on couch sofa"
            )
            threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                value=0.1,  # low default: OwlViT scores tend to be small
                step=0.05,
                info="Lower values detect more objects but may include false positives"
            )
            detect_btn = gr.Button("Detect Objects", variant="primary")

        # Right column: annotated output.
        with gr.Column():
            output_image = gr.Image(
                label="Detected Objects",
                type="pil",
                height=400
            )

    # Example queries — fill only text + threshold; the user supplies an image.
    gr.Markdown("### Examples")
    gr.Examples(
        examples=[
            ["a dog on couch sofa", 0.1],
            ["person, car", 0.1],
            ["cat, dog", 0.1],
        ],
        inputs=[text_input, threshold],
        label="Try these queries"
    )

    # Set up the function call: button click runs detection.
    detect_btn.click(
        fn=detect_objects,
        inputs=[image_input, text_input, threshold],
        outputs=output_image
    )

    # Also allow Enter key in the textbox to trigger detection.
    text_input.submit(
        fn=detect_objects,
        inputs=[image_input, text_input, threshold],
        outputs=output_image
    )

# Launched at import time — standard for Hugging Face Spaces entry points.
demo.launch()
197
+
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ scipy==1.12.0
2
+ torch==2.3.0
3
+ torchvision==0.18.0
4
+ transformers>=4.35.0
5
+ pillow>=9.0
6
+ timm>=0.9.0
7
+ numpy>=1.21.0
8
+ requests>=2.25.0
9
+ gradio>=4.0.0
10
+