abdrabo01 commited on
Commit
887bc13
·
verified ·
1 Parent(s): dc00e54

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +57 -30
main.py CHANGED
@@ -1,56 +1,83 @@
1
  from fastapi import FastAPI
 
2
  import uvicorn
3
  import base64
4
  import cv2
5
  import numpy as np
6
  from ultralytics import YOLO
7
  from datetime import datetime
8
- from pydantic import BaseModel
9
-
10
- app = FastAPI()
11
 
 
12
  model = YOLO("pcb_component_detection_best.pt")
13
 
14
  class ImageRequest(BaseModel):
 
15
  image: str
16
 
17
  @app.get("/")
18
  async def root():
 
19
  current_time = datetime.now().isoformat()
20
- return {"message": "PCB components API works", "time": current_time}
 
 
 
21
 
22
  @app.post("/predict")
23
  async def predict(request: ImageRequest):
24
- # Decode Base64
25
- image_bytes = base64.b64decode(request.image)
26
- np_arr = np.frombuffer(image_bytes, np.uint8)
27
- image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
28
- if image is None:
29
- return {"error": "Invalid image"}
30
 
31
- # Inference
32
- results = model.predict(image)
33
- result = results[0]
34
-
35
- # Response
36
- json_result = {}
37
- class_counters = {}
 
 
38
 
39
- for box in result.boxes:
40
- class_id = int(box.cls[0])
41
- class_name = result.names[class_id]
42
- x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
43
-
44
- if class_name in class_counters:
45
- class_counters[class_name] += 1
46
- key = f"{class_name}_{class_counters[class_name]}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  else:
48
- class_counters[class_name] = 1
49
- key = class_name
 
 
 
 
50
 
51
- json_result[key] = [x1, y1, x2, y2]
52
-
53
- return json_result
54
 
55
  if __name__ == "__main__":
56
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
  import uvicorn
4
  import base64
5
  import cv2
6
  import numpy as np
7
  from ultralytics import YOLO
8
  from datetime import datetime
 
 
 
9
 
10
+ app = FastAPI(title="PCB Component Detection API")
11
  model = YOLO("pcb_component_detection_best.pt")
12
 
13
  class ImageRequest(BaseModel):
14
+ """Request model for image processing endpoint."""
15
  image: str
16
 
17
  @app.get("/")
18
  async def root():
19
+ """Root endpoint to verify API status."""
20
  current_time = datetime.now().isoformat()
21
+ return {
22
+ "message": "PCB Components API works",
23
+ "time": current_time
24
+ }
25
 
26
  @app.post("/predict")
27
  async def predict(request: ImageRequest):
28
+ """
29
+ Process an image to detect PCB components.
 
 
 
 
30
 
31
+ Args:
32
+ request: Contains base64 encoded image
33
+
34
+ Returns:
35
+ JSON with detection statistics and bounding boxes
36
+ """
37
+ # Validate image input
38
+ if not request.image:
39
+ return {"error": "Invalid Image"}
40
 
41
+ try:
42
+ image_bytes = base64.b64decode(request.image, validate=True)
43
+
44
+ np_arr = np.frombuffer(image_bytes, np.uint8)
45
+ image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
46
+
47
+ if image is None:
48
+ return {"error": "Invalid image"}
49
+
50
+ results = model.predict(image)
51
+ result = results[0]
52
+
53
+ json_result = {}
54
+ class_counters = {}
55
+
56
+ for box in result.boxes:
57
+ class_id = int(box.cls[0])
58
+ class_name = result.names[class_id]
59
+ x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
60
+
61
+ if class_name in class_counters:
62
+ class_counters[class_name] += 1
63
+ else:
64
+ class_counters[class_name] = 1
65
+
66
+ key = f"{class_name}{class_counters[class_name]}" if class_counters[class_name] > 1 else class_name
67
+ json_result[key] = [x1, y1, x2, y2]
68
+
69
+ if hasattr(result, "summary") and isinstance(result.summary, dict):
70
+ statistics_summary = result.summary
71
  else:
72
+ statistics_summary = {name: count for name, count in class_counters.items()}
73
+
74
+ return {
75
+ "statistics": statistics_summary,
76
+ "components": json_result
77
+ }
78
 
79
+ except Exception as e:
80
+ return {"error": f"Invalid Image: {str(e)}"}
 
81
 
82
  if __name__ == "__main__":
83
  uvicorn.run(app, host="0.0.0.0", port=7860)