# model / app.py
# francmeister's picture
# Update app.py
# 30b67b6 verified
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from inference import ObjectDetector
import numpy as np
import cv2
import socket
import uvicorn
# Configuration
# ONNX model file handed to ObjectDetector below (relative to the working dir).
MODEL_ONNX_PATH = "model.onnx"

# Detector labels. The entries are in alphabetical order.
# NOTE(review): presumably list position == class id emitted by the model, so
# this order must not change independently of the model — confirm.
CLASS_NAMES = [
    'Butter_Dukat_Maslac_Stick_250g',
    'Butter_Zbregov_Maslac_Stick_250g',
    'Butter_Zdenka_Maslac_Stick_250g',
    'Cheese_President_Gouda_Cube_250g',
    'Chicken_Cekin_Pileca_Prsa_500g',
    'Coffee_Franch_Crema_Bag_175g',
    'Coffee_Franch_Crema_Box_250g',
    'Coffee_Franch_Instant_Crema_80g',
    'Coffee_Franch_Intense_Box_250g',
    'Coffee_Franch_Original_Box_250g',
    'Coffee_Franch_Sensual_Box_250g',
    'Drink_CocaCola_Original_Bottle_1l',
    'Flour_Mlineta_Brasno_Ostro_1kg',
    'Juice_Vindi_Naranca_Nektar_1l',
    'Ketchup_Zvijezda_Mild_Bottle_500g',
    'Mayonnaise_Zvijezda_Delicate_Bottle_400g',
    'Milk_Zbregov_Trajno_28_1l',
    'Oil_Dijamant_Suncokretovo_Bottle_1l',
    'Oil_Zvijezda_Suncokretovo_Ulje_1l',
    'Pasta_Barilla_Fusilli_Box_500g',
    'Rice_Gallo_Long_Grain_900g',
    'Rice_Kplus_Arborio_BijeliDugi_1kg',
    'Salt_SolanaPag_Sitna_Box_1kg',
    'Spaghetti_PastaZara_Spaghettini_Bag_500g',
    'Tuna_RioMare_Tonno_Oliva',
]

# Value passed to ObjectDetector as `input_size` (presumably the square side
# length, in pixels, the model expects — confirm in inference.py).
INPUT_SIZE = 640
# Initialize detector
# A single module-level instance, built once at import time and shared by the
# request handlers below.
detector = ObjectDetector(
    input_size=INPUT_SIZE,
    class_names=CLASS_NAMES,
    model_path=MODEL_ONNX_PATH,
)
# Initialize FastAPI
app = FastAPI()

# CORS: any origin may call this public API.
# allow_credentials is False on purpose. The CORS spec forbids combining
# credentials with the "*" wildcard, and Starlette works around that by
# echoing the caller's Origin, which would let any website make credentialed
# requests. None of the endpoints in this file read cookies or auth headers,
# so credentials are not needed. With credentials off, the "*" wildcards
# below (including expose_headers) are honoured by browsers as intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)
def get_base_url():
    """Return the public HTTPS base URL the app is expected to be reachable at.

    Prefers the ``SPACE_HOST`` environment variable, which Hugging Face Spaces
    sets to the Space's public host (e.g. ``user-space.hf.space``).

    When it is unset or empty, falls back to the previous behaviour:
    ``https://<socket hostname>.hf.space``. Inside a container that hostname
    may not be the Space's public subdomain, so the fallback is best-effort.
    Used only for the informational URLs printed at startup.
    """
    import os

    space_host = os.environ.get("SPACE_HOST")
    if space_host:
        return f"https://{space_host}"
    return f"https://{socket.gethostname()}.hf.space"
@app.options("/detect")
async def detect_options():
    """Answer a plain OPTIONS request for ``/detect``.

    HTTP clients read the allowed methods from the ``Allow`` response header,
    so the methods are now sent there. The JSON body returned by earlier
    versions is kept unchanged for backward compatibility.

    Presumably, CORS preflight requests are answered by the CORS middleware
    before they reach this handler.
    """
    from fastapi.responses import JSONResponse

    return JSONResponse(
        content={"Allow": "POST"},
        headers={"Allow": "POST, OPTIONS"},
    )
@app.get("/")
def health_check():
    """Liveness endpoint: return a constant payload showing the service is up."""
    return dict(status="OK", model="Object Detection API")
@app.post("/detect")
async def detect_objects(file: UploadFile = File(...)):
    """Run the object detector on an uploaded image.

    Args:
        file: Multipart upload whose declared content type must start with
            ``image/``.

    Returns:
        ``{"status": "success", "detections": ..., "count": ...}``.
        ``detections`` is whatever ``detector.predict`` returns for the
        decoded image, and ``count`` is its ``len``.

    Raises:
        HTTPException: 400 when the upload is not declared as an image, is
            empty, or cannot be decoded by OpenCV. 500
            (``"Processing error: ..."``) for any other failure.
    """
    try:
        # content_type can be None when the client sends no Content-Type for
        # the part; treat that as "not an image" instead of failing with a 500.
        if not (file.content_type or "").startswith("image/"):
            raise HTTPException(400, "File must be an image")
        image_data = await file.read()
        # cv2.imdecode may raise (rather than return None) on an empty buffer,
        # which would surface as a 500; reject empty uploads as bad input.
        if not image_data:
            raise HTTPException(400, "Invalid image data")
        image = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        # imdecode returns None for undecodable data. This check must come
        # before any further OpenCV call. Otherwise cvtColor(None) raises,
        # bad input comes back as a 500, and this 400 branch is never reached.
        if image is None:
            raise HTTPException(400, "Invalid image data")
        # OpenCV decodes to BGR; convert to RGB before inference.
        # NOTE(review): earlier comments disagreed on whether ObjectDetector
        # expects RGB or BGR — confirm against inference.ObjectDetector.predict.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        detections = detector.predict(image)
        return {
            "status": "success",
            "detections": detections,
            "count": len(detections),
        }
    except HTTPException:
        # Deliberate client errors pass through with their own status code.
        raise
    except Exception as e:
        raise HTTPException(500, f"Processing error: {str(e)}") from e
if __name__ == "__main__":
    # Script entry point: print the expected public URLs, then serve the app
    # on all interfaces at port 7860.
    public_url = get_base_url()
    for label, value in (
        ("Base URL", public_url),
        ("API endpoint", f"{public_url}/detect"),
    ):
        print(f"{label}: {value}")
    uvicorn.run(app, host="0.0.0.0", port=7860)