# Inventory_Detect / main.py
# VishnuCodes's picture
# Update main.py
# 1677f87 verified
from fastapi import FastAPI, File, UploadFile, Response
from fastapi.responses import FileResponse
from ultralytics import YOLO
import cv2
import numpy as np
from PIL import Image
import io
import os
# FastAPI application object; route handlers below register themselves on it.
app = FastAPI()

# Directory created at import time. Nothing in the visible code writes to it
# yet -- presumably reserved for persisting uploads; confirm before removing.
uploads_dir = 'uploads'
# exist_ok=True avoids the check-then-create race of an explicit
# os.path.exists() test when several workers start simultaneously.
os.makedirs(uploads_dir, exist_ok=True)

# Load the YOLO model once at import time so every request reuses it.
yolo_model_path = 'best.pt'
yolo_model = YOLO(yolo_model_path)
def _union_of_boxes(boxes):
    """Return the smallest box that encloses every box in *boxes*.

    Args:
        boxes: Non-empty sequence of ``(xmin, ymin, xmax, ymax)`` numbers.

    Returns:
        A tuple ``(xmin, ymin, xmax, ymax)`` of ints. Each coordinate is
        truncated toward zero with ``int()`` (not rounded).
    """
    xmins, ymins, xmaxs, ymaxs = zip(*boxes)
    return int(min(xmins)), int(min(ymins)), int(max(xmaxs)), int(max(ymaxs))


@app.post("/detect_objects")
async def detect_objects(file: UploadFile = File(...)):
    """Run YOLO detection on an uploaded image and return it annotated.

    Args:
        file: Uploaded image; its name must end in ``.png``, ``.jpg`` or
            ``.jpeg`` (case-insensitive).

    Returns:
        On success, a PNG ``Response`` holding the annotated image, with the
        union of all detected boxes exposed through the ``X-Min``, ``Y-Min``,
        ``X-Max`` and ``Y-Max`` response headers.
        Otherwise a ``{"error": ...}`` dict when the file extension is not
        accepted, the bytes cannot be decoded as an image, or the model
        reports no detections.
    """
    # UploadFile.filename may be missing; treat that like a bad extension
    # instead of raising AttributeError on None.
    filename = file.filename or ""
    if not filename.lower().endswith(('.png', '.jpg', '.jpeg')):
        return {"error": "Invalid file format"}

    # Decode the uploaded bytes. imdecode yields None for data it cannot
    # parse, and an empty buffer is rejected outright, so guard both.
    contents = await file.read()
    img = None
    if contents:
        img = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        return {"error": "Could not decode image"}

    # Perform object detection with YOLO.
    results = yolo_model(img)
    boxes = results[0].boxes.xyxy.tolist()
    if not boxes:
        # Without this guard the min/max sentinels would stay at +/-inf and
        # int(inf) raises OverflowError.
        return {"error": "No objects detected"}

    combined_xmin, combined_ymin, combined_xmax, combined_ymax = _union_of_boxes(boxes)

    # plot() returns the annotated frame in OpenCV's BGR channel order, while
    # PIL expects RGB; convert so the colours in the PNG are not swapped.
    annotated_img = results[0].plot()
    annotated_rgb = cv2.cvtColor(annotated_img, cv2.COLOR_BGR2RGB)

    # Encode the annotated image as PNG bytes for the response body.
    img_byte_arr = io.BytesIO()
    Image.fromarray(annotated_rgb).save(img_byte_arr, format='PNG')

    # Expose the combined bounding box as response headers.
    metadata = {
        "X-Min": combined_xmin,
        "Y-Min": combined_ymin,
        "X-Max": combined_xmax,
        "Y-Max": combined_ymax,
    }
    response = Response(img_byte_arr.getvalue(), media_type="image/png")
    for key, value in metadata.items():
        response.headers[key] = str(value)
    return response