Janiopi's picture
Upload 4 files
2b6099a verified
import base64
from io import BytesIO
from typing import List

import cv2
import numpy as np
from fastapi import FastAPI, File, UploadFile, HTTPException, Query
from fastapi.responses import HTMLResponse
from PIL import Image
from pydantic import BaseModel
# FastAPI application instance; endpoint decorators such as @app.post
# register their routes on this object.
app = FastAPI()
def detect_cat(image, draw_rectangles=False):
    """Detect frontal cat faces in an image using OpenCV's Haar cascade.

    Args:
        image: NumPy image array. A 2-D array is treated as already
            grayscale, a 4-channel array as BGRA, and any other layout is
            handed to OpenCV as BGR.
        draw_rectangles: When true, a box is drawn around every detection
            directly on ``image``. The array is modified in place, so it
            must be writable.

    Returns:
        A dict with the keys ``"found"`` (``"SI"`` when at least one face
        was detected, otherwise ``"NO"``), ``"count"`` (number of
        detections) and ``"locations"`` (list of ``[x, y, w, h]`` boxes).

    Raises:
        RuntimeError: If the cascade definition file could not be loaded.
    """
    # The path is relative, so it is resolved against the process working
    # directory. OpenCV does not raise on a missing file; it returns an
    # empty classifier, which would otherwise fail later with an opaque
    # assertion inside detectMultiScale.
    cat_cascade = cv2.CascadeClassifier('haarcascade_frontalcatface.xml')
    if cat_cascade.empty():
        raise RuntimeError(
            "Could not load cascade file 'haarcascade_frontalcatface.xml'"
        )

    # Pick the grayscale conversion (and a matching box colour) from the
    # channel layout instead of assuming three channels.
    if image.ndim == 2:
        gray = image
        box_color = (255, 255, 255)
    elif image.shape[2] == 4:
        gray = cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
        # Opaque alpha, otherwise the box would be drawn fully transparent.
        box_color = (0, 255, 0, 255)
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        box_color = (0, 255, 0)

    cats = cat_cascade.detectMultiScale(
        gray,
        scaleFactor=1.1,
        minNeighbors=3,
        minSize=(30, 30)
    )

    # detectMultiScale yields an empty tuple when nothing is found, so the
    # length is checked before calling the ndarray-only tolist(). The list
    # holds plain Python ints, which are JSON-serialisable and are accepted
    # as drawing coordinates.
    locations = cats.tolist() if len(cats) > 0 else []

    if draw_rectangles:
        for x, y, w, h in locations:
            cv2.rectangle(image, (x, y), (x + w, y + h), box_color, 2)

    return {
        "found": "SI" if locations else "NO",
        "count": len(locations),
        "locations": locations
    }
@app.post('/predict/')
async def predict(
    file: UploadFile = File(...),
    tipo: str = Query(...),
    draw_boxes: bool = Query(False)
):
    """Run cat-face detection on an uploaded image.

    Args:
        file: The uploaded image file.
        tipo: Required query parameter. It is accepted for interface
            compatibility but is not used by this handler.
        draw_boxes: When true, the response also carries the image with
            the detections drawn on it.

    Returns:
        ``{"prediction": ...}`` with the dict produced by ``detect_cat``.
        When ``draw_boxes`` is true an ``"image"`` key is added holding
        the annotated PNG as a base64-encoded ASCII string.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image,
            500 when detection or encoding of the result fails.
    """
    raw = await file.read()

    try:
        with Image.open(BytesIO(raw)) as uploaded:
            # Normalise every mode (RGBA, palette, grayscale, ...) to
            # three channels; np.array copies, so the result is writable.
            rgb = np.array(uploaded.convert("RGB"))
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        # Pillow's "cannot identify image" error is an OSError subclass.
        raise HTTPException(
            status_code=400, detail=f"Invalid image upload: {exc}"
        ) from exc

    try:
        # detect_cat converts with a BGR colour code, so hand it BGR data.
        bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
        prediction = detect_cat(bgr, draw_rectangles=draw_boxes)

        if not draw_boxes:
            return {"prediction": prediction}

        # Back to RGB for Pillow, then encode the annotated image as PNG.
        result_image = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        png_buffer = BytesIO()
        result_image.save(png_buffer, format='PNG')

        # Raw PNG bytes are not valid UTF-8 and cannot be placed in a JSON
        # body, so they are transported as base64 text.
        encoded_png = base64.b64encode(png_buffer.getvalue()).decode("ascii")
        return {
            "prediction": prediction,
            "image": encoded_png
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e