Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -167,23 +167,23 @@ def process_image_detection(image, target_label, surprise_rating):
|
|
| 167 |
image = image.convert('RGB')
|
| 168 |
|
| 169 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 170 |
-
print(f"Using device: {device}")
|
| 171 |
|
| 172 |
# Get original image DPI and size
|
| 173 |
original_dpi = image.info.get('dpi', (72, 72))
|
| 174 |
original_size = image.size
|
| 175 |
-
print(f"Image size: {original_size}")
|
| 176 |
|
| 177 |
# Calculate relative font size based on image dimensions
|
| 178 |
base_fontsize = min(original_size) / 40
|
| 179 |
|
| 180 |
-
print("Loading models...")
|
| 181 |
owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
|
| 182 |
owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
|
| 183 |
sam_processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
|
| 184 |
sam_model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base").to(device)
|
| 185 |
|
| 186 |
-
print("Running object detection...")
|
| 187 |
inputs = owlv2_processor(text=[target_label], images=image, return_tensors="pt").to(device)
|
| 188 |
with torch.no_grad():
|
| 189 |
outputs = owlv2_model(**inputs)
|
|
@@ -204,10 +204,10 @@ def process_image_detection(image, target_label, surprise_rating):
|
|
| 204 |
max_score = scores[max_score_idx].item()
|
| 205 |
|
| 206 |
if max_score > 0.2:
|
| 207 |
-
print("Processing detection results...")
|
| 208 |
box = results["boxes"][max_score_idx].cpu().numpy()
|
| 209 |
|
| 210 |
-
print("Running SAM model...")
|
| 211 |
# Convert image to numpy array if needed for SAM
|
| 212 |
if isinstance(image, Image.Image):
|
| 213 |
image_np = np.array(image)
|
|
@@ -215,7 +215,7 @@ def process_image_detection(image, target_label, surprise_rating):
|
|
| 215 |
image_np = image
|
| 216 |
|
| 217 |
sam_inputs = sam_processor(
|
| 218 |
-
image_np,
|
| 219 |
input_boxes=[[[box[0], box[1], box[2], box[3]]]],
|
| 220 |
return_tensors="pt"
|
| 221 |
).to(device)
|
|
@@ -229,7 +229,7 @@ def process_image_detection(image, target_label, surprise_rating):
|
|
| 229 |
sam_inputs["reshaped_input_sizes"].cpu()
|
| 230 |
)
|
| 231 |
|
| 232 |
-
print(f"Mask type: {type(masks)}, Mask shape: {len(masks)}")
|
| 233 |
mask = masks[0]
|
| 234 |
if isinstance(mask, torch.Tensor):
|
| 235 |
mask = mask.numpy()
|
|
@@ -266,24 +266,23 @@ def process_image_detection(image, target_label, surprise_rating):
|
|
| 266 |
)
|
| 267 |
|
| 268 |
plt.axis('off')
|
| 269 |
-
|
| 270 |
-
print("Saving final image...")
|
| 271 |
try:
|
| 272 |
-
#
|
| 273 |
-
buf = io.BytesIO()
|
| 274 |
-
|
| 275 |
-
# Force figure to be in a format we can save
|
| 276 |
fig.canvas.draw()
|
| 277 |
|
| 278 |
-
# Get the
|
| 279 |
-
|
| 280 |
-
|
|
|
|
| 281 |
|
| 282 |
-
#
|
| 283 |
-
output_image = Image.fromarray(
|
| 284 |
|
| 285 |
-
# Resize if needed
|
| 286 |
-
output_image
|
|
|
|
| 287 |
|
| 288 |
# Save to final buffer
|
| 289 |
final_buf = io.BytesIO()
|
|
@@ -294,13 +293,17 @@ def process_image_detection(image, target_label, surprise_rating):
|
|
| 294 |
plt.close(fig)
|
| 295 |
|
| 296 |
return final_buf
|
| 297 |
-
|
| 298 |
except Exception as e:
|
| 299 |
-
print(f"Save error details: {str(e)}")
|
| 300 |
print(f"Figure type: {type(fig)}")
|
| 301 |
print(f"Canvas type: {type(fig.canvas)}")
|
| 302 |
raise
|
| 303 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
|
| 305 |
def process_and_analyze(image):
|
| 306 |
if image is None:
|
|
|
|
| 167 |
image = image.convert('RGB')
|
| 168 |
|
| 169 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 170 |
+
print(f"Using device: {device}")
|
| 171 |
|
| 172 |
# Get original image DPI and size
|
| 173 |
original_dpi = image.info.get('dpi', (72, 72))
|
| 174 |
original_size = image.size
|
| 175 |
+
print(f"Image size: {original_size}")
|
| 176 |
|
| 177 |
# Calculate relative font size based on image dimensions
|
| 178 |
base_fontsize = min(original_size) / 40
|
| 179 |
|
| 180 |
+
print("Loading models...")
|
| 181 |
owlv2_processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
|
| 182 |
owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
|
| 183 |
sam_processor = AutoProcessor.from_pretrained("facebook/sam-vit-base")
|
| 184 |
sam_model = AutoModelForMaskGeneration.from_pretrained("facebook/sam-vit-base").to(device)
|
| 185 |
|
| 186 |
+
print("Running object detection...")
|
| 187 |
inputs = owlv2_processor(text=[target_label], images=image, return_tensors="pt").to(device)
|
| 188 |
with torch.no_grad():
|
| 189 |
outputs = owlv2_model(**inputs)
|
|
|
|
| 204 |
max_score = scores[max_score_idx].item()
|
| 205 |
|
| 206 |
if max_score > 0.2:
|
| 207 |
+
print("Processing detection results...")
|
| 208 |
box = results["boxes"][max_score_idx].cpu().numpy()
|
| 209 |
|
| 210 |
+
print("Running SAM model...")
|
| 211 |
# Convert image to numpy array if needed for SAM
|
| 212 |
if isinstance(image, Image.Image):
|
| 213 |
image_np = np.array(image)
|
|
|
|
| 215 |
image_np = image
|
| 216 |
|
| 217 |
sam_inputs = sam_processor(
|
| 218 |
+
image_np,
|
| 219 |
input_boxes=[[[box[0], box[1], box[2], box[3]]]],
|
| 220 |
return_tensors="pt"
|
| 221 |
).to(device)
|
|
|
|
| 229 |
sam_inputs["reshaped_input_sizes"].cpu()
|
| 230 |
)
|
| 231 |
|
| 232 |
+
print(f"Mask type: {type(masks)}, Mask shape: {len(masks)}")
|
| 233 |
mask = masks[0]
|
| 234 |
if isinstance(mask, torch.Tensor):
|
| 235 |
mask = mask.numpy()
|
|
|
|
| 266 |
)
|
| 267 |
|
| 268 |
plt.axis('off')
|
| 269 |
+
|
| 270 |
+
print("Saving final image...")
|
| 271 |
try:
|
| 272 |
+
# Force figure to be rendered
|
|
|
|
|
|
|
|
|
|
| 273 |
fig.canvas.draw()
|
| 274 |
|
| 275 |
+
# Get the RGBA buffer from the figure
|
| 276 |
+
w, h = fig.canvas.get_width_height()
|
| 277 |
+
buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
| 278 |
+
buf.shape = (h, w, 3)
|
| 279 |
|
| 280 |
+
# Create PIL Image from buffer
|
| 281 |
+
output_image = Image.fromarray(buf)
|
| 282 |
|
| 283 |
+
# Resize to original size if needed
|
| 284 |
+
if output_image.size != original_size:
|
| 285 |
+
output_image = output_image.resize(original_size, Image.Resampling.LANCZOS)
|
| 286 |
|
| 287 |
# Save to final buffer
|
| 288 |
final_buf = io.BytesIO()
|
|
|
|
| 293 |
plt.close(fig)
|
| 294 |
|
| 295 |
return final_buf
|
| 296 |
+
|
| 297 |
except Exception as e:
|
| 298 |
+
print(f"Save error details: {str(e)}")
|
| 299 |
print(f"Figure type: {type(fig)}")
|
| 300 |
print(f"Canvas type: {type(fig.canvas)}")
|
| 301 |
raise
|
| 302 |
|
| 303 |
+
except Exception as e:
|
| 304 |
+
print(f"Process image detection error: {str(e)}")
|
| 305 |
+
print(f"Error occurred at line {e.__traceback__.tb_lineno}")
|
| 306 |
+
raise
|
| 307 |
|
| 308 |
def process_and_analyze(image):
|
| 309 |
if image is None:
|