Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from inference import ObjectDetector | |
| import numpy as np | |
| import cv2 | |
| import socket | |
| import uvicorn | |
# Configuration

# Model file handed to ObjectDetector as `model_path`.
MODEL_ONNX_PATH = "model.onnx"

# Detection labels, one per whitespace-separated token. The resulting list is
# identical to an explicit list literal. Order matters: presumably entry i is
# the label for the model's class id i — confirm against the training config
# before reordering or adding names.
CLASS_NAMES = """
    Butter_Dukat_Maslac_Stick_250g
    Butter_Zbregov_Maslac_Stick_250g
    Butter_Zdenka_Maslac_Stick_250g
    Cheese_President_Gouda_Cube_250g
    Chicken_Cekin_Pileca_Prsa_500g
    Coffee_Franch_Crema_Bag_175g
    Coffee_Franch_Crema_Box_250g
    Coffee_Franch_Instant_Crema_80g
    Coffee_Franch_Intense_Box_250g
    Coffee_Franch_Original_Box_250g
    Coffee_Franch_Sensual_Box_250g
    Drink_CocaCola_Original_Bottle_1l
    Flour_Mlineta_Brasno_Ostro_1kg
    Juice_Vindi_Naranca_Nektar_1l
    Ketchup_Zvijezda_Mild_Bottle_500g
    Mayonnaise_Zvijezda_Delicate_Bottle_400g
    Milk_Zbregov_Trajno_28_1l
    Oil_Dijamant_Suncokretovo_Bottle_1l
    Oil_Zvijezda_Suncokretovo_Ulje_1l
    Pasta_Barilla_Fusilli_Box_500g
    Rice_Gallo_Long_Grain_900g
    Rice_Kplus_Arborio_BijeliDugi_1kg
    Salt_SolanaPag_Sitna_Box_1kg
    Spaghetti_PastaZara_Spaghettini_Bag_500g
    Tuna_RioMare_Tonno_Oliva
""".split()

# Passed to ObjectDetector as `input_size` (presumably the square input side
# in pixels — confirm in inference.ObjectDetector).
INPUT_SIZE = 640
# Build a single detector at import time; every detection request reuses this
# instance rather than constructing its own (presumably ObjectDetector loads
# the ONNX model in its constructor — confirm in inference.py).
detector = ObjectDetector(
    input_size=INPUT_SIZE,
    class_names=CLASS_NAMES,
    model_path=MODEL_ONNX_PATH,
)
# Initialize FastAPI
app = FastAPI()

# CORS: let browser clients on any origin call the API.
# Credentials are deliberately disabled. A wildcard origin combined with
# allow_credentials=True is not permitted by the CORS spec (FastAPI's docs say
# to list origins explicitly in that case). Starlette copes by echoing back the
# request's Origin whenever a cookie is present, which would let any website
# make cookie-authenticated requests to this host. None of the endpoints
# defined in this file read cookies or auth headers. If credentials are ever
# needed, replace "*" with an explicit list of trusted origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)
def get_base_url():
    """Return the public HTTPS base URL of this Hugging Face Space.

    Inside a running Space, Hugging Face sets the ``SPACE_HOST`` environment
    variable to the Space's public hostname (e.g. ``user-space.hf.space``), and
    that value is used when present and non-empty. Otherwise this falls back to
    ``https://<machine hostname>.hf.space``. That fallback is only a guess: the
    machine hostname of a Space container is not its public subdomain.

    Returns:
        The base URL as a string, without a trailing slash.
    """
    import os

    space_host = os.environ.get("SPACE_HOST")
    if space_host:
        return f"https://{space_host}"
    return f"https://{socket.gethostname()}.hf.space"
@app.options("/detect")
async def detect_options():
    """Answer a plain OPTIONS request on ``/detect`` with the allowed method.

    Without the route decorator this handler was never registered. CORS
    preflight requests (those carrying ``Access-Control-Request-Method``) are
    answered by ``CORSMiddleware`` before they reach this handler.

    Returns:
        ``{"Allow": "POST"}`` as the JSON body.
    """
    return {"Allow": "POST"}
@app.get("/")
def health_check():
    """Report that the API process is up.

    Without the route decorator this handler was never registered. It is
    mounted at the root path so the Space's base URL answers with a status
    payload instead of a 404.

    Returns:
        ``{"status": "OK", "model": "Object Detection API"}``.
    """
    return {"status": "OK", "model": "Object Detection API"}
@app.post("/detect")
async def detect_objects(file: UploadFile = File(...)):
    """Run object detection on an uploaded image.

    Args:
        file: Multipart upload whose declared content type must start with
            ``image/``.

    Returns:
        A dict with ``status`` (``"success"``), ``detections`` (whatever
        ``detector.predict`` returns) and ``count`` (``len(detections)``).

    Raises:
        HTTPException: 400 if the upload is not declared as an image, is empty,
            or cannot be decoded by OpenCV; 500 for any other failure during
            processing (the original exception is chained and logged).
    """
    import logging

    try:
        # content_type can be None when the client omits it; treat that as
        # "not an image" instead of crashing with AttributeError (-> 500).
        if not (file.content_type or "").startswith("image/"):
            raise HTTPException(400, "File must be an image")

        image_data = await file.read()
        # cv2.imdecode asserts on an empty buffer, which used to surface as 500.
        if not image_data:
            raise HTTPException(400, "Invalid image data")

        image = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        # Validate the decode BEFORE any further OpenCV call: cvtColor on None
        # raises cv2.error, which previously turned a bad upload into a 500.
        if image is None:
            raise HTTPException(400, "Invalid image data")

        # imdecode returns BGR; convert to RGB before inference (current behavior).
        # NOTE(review): an earlier comment claimed the model expects BGR — confirm
        # against ObjectDetector.predict's preprocessing before changing this.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        detections = detector.predict(image)
        return {
            "status": "success",
            "detections": detections,
            "count": len(detections),
        }
    except HTTPException:
        # Deliberate client errors pass through unchanged.
        raise
    except Exception as e:
        # Keep the traceback in the server log; the client only sees the message.
        logging.getLogger(__name__).exception("Object detection failed")
        raise HTTPException(500, f"Processing error: {str(e)}") from e
if __name__ == "__main__":
    # Local entry point: announce where the API will be reachable, then serve
    # on all interfaces at port 7860 (the port this app pins for the Space).
    public_url = get_base_url()
    print(f"Base URL: {public_url}")
    print(f"API endpoint: {public_url}/detect")
    uvicorn.run(app, port=7860, host="0.0.0.0")