VishnuCodes committed on
Commit
1677f87
·
verified ·
1 Parent(s): e682c18

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +65 -20
main.py CHANGED
@@ -1,30 +1,75 @@
1
  from fastapi import FastAPI, File, UploadFile, Response
 
 
 
 
2
  from PIL import Image
3
- import torch
4
  import io
 
5
 
6
  app = FastAPI()
 
7
 
8
- # Load the YOLOv5 model
9
- model = torch.hub.load('ultralytics/yolov5', 'custom', path='best.pt')
10
 
11
@app.post("/detect/")
async def detect(file: UploadFile = File(...)):
    """Run the YOLOv5 model on an uploaded image and return it annotated.

    The upload is decoded with PIL, passed through the module-level model,
    and the rendered (boxes drawn in place) result is streamed back as JPEG.
    """
    # Pull the whole upload into memory and decode it.
    raw = await file.read()
    pil_img = Image.open(io.BytesIO(raw))

    # Inference; render() draws the detections onto results.ims in place.
    results = model(pil_img)
    results.render()

    # Re-encode the annotated frame as JPEG.
    buffer = io.BytesIO()
    Image.fromarray(results.ims[0]).save(buffer, format='JPEG')
    buffer.seek(0)

    # Ship the encoded bytes straight back to the client.
    return Response(content=buffer.getvalue(), media_type="image/jpeg")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI, File, UploadFile, Response
2
+ from fastapi.responses import FileResponse
3
+ from ultralytics import YOLO
4
+ import cv2
5
+ import numpy as np
6
  from PIL import Image
 
7
  import io
8
+ import os
9
 
10
app = FastAPI()

# Directory for uploaded files; created eagerly so request handlers can rely
# on it existing. exist_ok=True avoids the check-then-create race that the
# os.path.exists() + os.makedirs() pair has when two workers start at once.
uploads_dir = 'uploads'
os.makedirs(uploads_dir, exist_ok=True)

# Load the YOLO model once at import time so every request reuses it.
yolo_model_path = 'best.pt'
yolo_model = YOLO(yolo_model_path)
 
 
19
 
20
@app.post("/detect_objects")
async def detect_objects(file: UploadFile = File(...)):
    """Run YOLO object detection on an uploaded image.

    Accepts a .png/.jpg/.jpeg upload, runs the module-level model on it, and
    returns the annotated image as PNG. The union bounding box of all
    detections is reported via the X-Min/Y-Min/X-Max/Y-Max response headers
    (all zeros when nothing is detected).

    Returns a JSON error body for non-image extensions or undecodable data.
    """
    # Guard clause: reject anything that is not an accepted image extension.
    if not file.filename.lower().endswith(('.png', '.jpg', '.jpeg')):
        return {"error": "Invalid file format"}

    # Decode the upload into a BGR OpenCV image.
    contents = await file.read()
    nparr = np.frombuffer(contents, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
    if img is None:
        # The extension lied or the payload is corrupt: imdecode returns
        # None instead of raising, so check explicitly rather than crash.
        return {"error": "Invalid file format"}

    # Perform object detection with YOLO.
    results = yolo_model(img)

    # Union of all detection boxes. Guard the no-detection case: the
    # original inf/-inf seeds would make int() raise OverflowError.
    boxes = results[0].boxes.xyxy.tolist()
    if boxes:
        combined_xmin = int(min(b[0] for b in boxes))
        combined_ymin = int(min(b[1] for b in boxes))
        combined_xmax = int(max(b[2] for b in boxes))
        combined_ymax = int(max(b[3] for b in boxes))
    else:
        combined_xmin = combined_ymin = combined_xmax = combined_ymax = 0

    # plot() returns the annotated frame as a BGR array; convert to RGB so
    # PIL encodes the colors correctly instead of swapping red and blue.
    annotated_img = cv2.cvtColor(results[0].plot(), cv2.COLOR_BGR2RGB)

    # Encode the annotated image as PNG for the response body.
    annotated_pil_image = Image.fromarray(annotated_img)
    img_byte_arr = io.BytesIO()
    annotated_pil_image.save(img_byte_arr, format='PNG')
    img_byte_arr.seek(0)

    # Bounding-box metadata travels in the response headers.
    metadata = {
        "X-Min": combined_xmin,
        "Y-Min": combined_ymin,
        "X-Max": combined_xmax,
        "Y-Max": combined_ymax,
    }
    response = Response(img_byte_arr.getvalue(), media_type="image/png")
    for key, value in metadata.items():
        response.headers[key] = str(value)
    return response