Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, Request | |
| import tensorflow as tf | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import logging | |
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the trained Keras model once at import time so every request reuses it.
# NOTE(review): the path is relative to the current working directory, so the
# server must be started from the directory that contains the model file.
model = tf.keras.models.load_model('recyclebot.keras')

# Class names indexed by the position of the model's output (argmax index).
# The order must presumably match the label order used during training.
CLASSES = ['Glass', 'Metal', 'Paperboard', 'Plastic-Polystyrene', 'Plastic-Regular']

# Create FastAPI app
app = FastAPI()

# CORS: any origin may call the API. Credentials are disabled because a
# wildcard origin must not be combined with credentials (per the CORS spec and
# the FastAPI docs), and the endpoints in this file read no cookies or auth
# headers. To allow credentials, list the specific trusted origins instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],  # Allow all HTTP methods
    allow_headers=["*"],  # Allow all headers
)
# Preprocess the image (decode, force RGB, resize, add batch axis; no normalization)
def preprocess_image(image_file, size=(240, 240)):
    """Decode an image into a single-image batch for the model.

    The image is converted to 3-channel RGB, resized, and given a leading
    batch axis. Pixel values are left unscaled (uint8, 0-255).
    NOTE(review): this assumes the model rescales pixels internally — confirm.

    Args:
        image_file: A path or binary file-like object accepted by
            ``PIL.Image.open``.
        size: Target ``(width, height)`` passed to ``cv2.resize``. Defaults to
            ``(240, 240)``, the input size the model is expected to take.

    Returns:
        A ``numpy.ndarray`` of shape ``(1, height, width, 3)``.

    Raises:
        Whatever decoding or resizing raises; the error is logged and re-raised.
    """
    try:
        with Image.open(image_file) as img:
            # Grayscale, palette and RGBA uploads would otherwise produce 1 or 4
            # channels and fail the 3-channel layout the model expects.
            rgb = np.asarray(img.convert("RGB"))
        # cv2.resize takes (width, height) and returns (height, width, 3).
        resized = cv2.resize(rgb, size)
        # Add the batch dimension for inference.
        return np.expand_dims(resized, axis=0)
    except Exception as e:
        logger.exception("Error in preprocess_image: %s", e)
        raise
# Image classification endpoint
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image into one of ``CLASSES``.

    Responds with ``{"prediction": <class name>}`` on success. If the upload
    cannot be decoded or preprocessed, responds 400 with ``{"error": ...}``.
    If inference itself fails, responds 500 with the same body shape.

    NOTE(review): decoding and ``model.predict`` are synchronous calls inside an
    ``async`` handler, so they block the event loop while they run.
    """
    logger.info("Received request for /predict")
    try:
        img_array = preprocess_image(file.file)
    except Exception as e:
        # A decode/resize failure means the client sent unusable data.
        logger.error("Error in /predict: %s", e)
        return JSONResponse(content={"error": str(e)}, status_code=400)
    try:
        # verbose=0 stops Keras from printing a progress bar on every request.
        scores = model.predict(img_array, verbose=0)
        # Single-image batch: take row 0 and map the best index to its name.
        predicted_class = CLASSES[int(np.argmax(scores[0]))]
    except Exception as e:
        # Failures here are server-side, not the client's fault.
        logger.exception("Error in /predict: %s", e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
    return JSONResponse(content={"prediction": predicted_class})
# NOTE(review): the "/" path for this status route is an assumption — confirm
# the path that clients and uptime checks expect.
@app.get("/")
async def working():
    """Liveness endpoint: always returns {"Status": "Working"}.

    It does not exercise the model; it only shows that the app is serving.
    """
    return JSONResponse(content={"Status": "Working"})
# Allow starting the server by executing this module directly, rather than
# through an external ASGI server command.
if __name__ == "__main__":
    import uvicorn

    # Listen on every interface, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)