KurtHHHHHH committed on
Commit
5d48d9a
·
verified ·
1 Parent(s): d96d239

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -2
app.py CHANGED
@@ -15,7 +15,8 @@ from torchvision import transforms
15
  # 2. Load model weights
16
  # --------------------------
17
  # Load the checkpoint directly as it was saved (a plain ResNet50 with custom fc head)
18
- state_dict = torch.load("best_stanford_cars_transfer_model.pth", map_location="cpu")
 
19
 
20
  # Create a ResNet50 and modify its fc to match the checkpoint
21
  from torchvision.models import resnet50
@@ -241,10 +242,61 @@ labels = [
241
  "smart fortwo Convertible 2012"
242
  ]
243
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  # --------------------------
245
  # 4. Preprocessing function
246
  # --------------------------
247
  def preprocess_image(img: Image.Image):
 
 
248
  transform = transforms.Compose([
249
  transforms.Resize((224, 224)), # match your training input size
250
  transforms.ToTensor(),
@@ -253,7 +305,7 @@ def preprocess_image(img: Image.Image):
253
  std=[0.229, 0.224, 0.225]
254
  )
255
  ])
256
- x = transform(img).unsqueeze(0) # add batch dimension
257
  return x
258
 
259
  # --------------------------
 
15
  # 2. Load model weights
16
  # --------------------------
17
  # Load the checkpoint directly as it was saved (a plain ResNet50 with custom fc head)
18
+ # state_dict = torch.load("best_stanford_cars_transfer_model.pth", map_location="cpu")
19
+ state_dict = torch.load("test_with_YOLO.pth", map_location="cpu")
20
 
21
  # Create a ResNet50 and modify its fc to match the checkpoint
22
  from torchvision.models import resnet50
 
242
  "smart fortwo Convertible 2012"
243
  ]
244
 
245
+
246
from ultralytics import YOLO
import numpy as np

# --------------------------
# Load YOLO model for cropping
# --------------------------
# Prefer GPU inference when one is visible to torch; YOLO takes a device
# string ("cuda"/"cpu") at call time, so we only record the choice here.
device_str = 'cuda' if torch.cuda.is_available() else 'cpu'

# The "nano" checkpoint is the smallest/fastest YOLOv8 variant — plenty
# for locating a single car before classification.
yolo_model = YOLO('yolov8n.pt')
print("YOLOv8 model loaded.")
255
+
256
# --------------------------
# Define YOLO cropping function
# --------------------------
def detect_and_crop_pil(pil_image, model=yolo_model, device=device_str, conf_thresh=0.25, pad_ratio=0.05):
    """Crop *pil_image* to the most prominent detected car.

    Runs YOLO inference on the image and, among detections of COCO class 2
    ("car"), selects the bounding box with the largest area, pads it by
    ``pad_ratio`` of its longer side, clamps it to the image bounds, and
    returns the cropped PIL image. If inference yields no boxes, or no car
    is among them, the original image is returned unchanged.
    """
    predictions = model(pil_image, imgsz=640, conf=conf_thresh, device=device, verbose=False)

    # Guard clauses: no result object at all, or an empty/missing box container.
    if not predictions:
        return pil_image
    detection = predictions[0]
    if detection.boxes is None or len(detection.boxes) == 0:
        return pil_image

    xyxy = detection.boxes.xyxy.cpu().numpy()
    try:
        class_ids = detection.boxes.cls.cpu().numpy().astype(int)
    except Exception:
        # Some result objects may lack class info; fall back to class 0 for all.
        class_ids = np.zeros(len(xyxy), dtype=int)

    # COCO class index 2 is "car"; bail out when no detection matches.
    car_rows = np.where(class_ids == 2)[0]
    if car_rows.size == 0:
        return pil_image

    # Keep the car detection covering the largest area.
    widths = xyxy[car_rows, 2] - xyxy[car_rows, 0]
    heights = xyxy[car_rows, 3] - xyxy[car_rows, 1]
    winner = car_rows[np.argmax(widths * heights)]
    left, top, right, bottom = xyxy[winner].astype(int)

    # Pad by a fraction of the box's longer side, staying inside the image.
    margin = int(max(right - left, bottom - top) * pad_ratio)
    left = max(0, left - margin)
    top = max(0, top - margin)
    right = min(pil_image.width, right + margin)
    bottom = min(pil_image.height, bottom + margin)

    return pil_image.crop((left, top, right, bottom))
292
+
293
+
294
  # --------------------------
295
  # 4. Preprocessing function
296
  # --------------------------
297
  def preprocess_image(img: Image.Image):
298
+ cropped_img = detect_and_crop_pil(img)
299
+
300
  transform = transforms.Compose([
301
  transforms.Resize((224, 224)), # match your training input size
302
  transforms.ToTensor(),
 
305
  std=[0.229, 0.224, 0.225]
306
  )
307
  ])
308
+ x = transform(cropped_img).unsqueeze(0) # add batch dimension
309
  return x
310
 
311
  # --------------------------