Spaces:
Sleeping
Sleeping
| # import base64 # Not needed | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import StreamingResponse | |
| from PIL import Image, ImageDraw, ImageFont | |
| import numpy as np | |
| import tensorflow as tf | |
| import io | |
| from typing import List | |
# ASGI application object.
app = FastAPI(title="Fruit Classifier API")

# Relative path to the saved Keras model; it is resolved against the
# working directory of the process.
MODEL_PATH = "fruit_classifier_model.h5"

# Best-effort load: if loading fails the app still starts with
# ``model = None``.
model = None
try:
    model = tf.keras.models.load_model(MODEL_PATH)
except Exception as e:
    print(f"Error loading model: {e}")
else:
    print(f"Model loaded successfully from {MODEL_PATH}")
# Label for each model output index, in order (taken from the training
# notebook). The position in this list is the class index.
CLASS_NAMES = [
    'Apple',
    'Apricots',
    'Avocado',
    'Banana',
    'Blackberries',
    'Blueberry',
    'Cantaloupe',
    'Cherry',
    'Coconut',
    'Dates',
    'Dragon fruit',
    'Fig',
    'Grapes',
    'Guava',
    'Jackfruit',
    'Kiwi',
    'Lemons',
    'Lychee',
    'Mango',
    'Olive',
    'Orange',
    'Papaya',
    'Pear',
    'Persimmon',
    'Pineapple',
    'Plum',
    'Pomegranate',
    'Rambutan',
    'Raspberry',
    'Salak',
    'Sapodilla',
    'Soursop',
    'Starfruit',
    'Strawberry',
    'Watermelon',
]
def preprocess_image(image: Image.Image) -> np.ndarray:
    """Turn a PIL image into a single-item batch for the classifier.

    The image is converted to RGB when it is in any other mode, resized to
    224x224, and returned as an array with a leading batch axis, i.e. shape
    ``(1, 224, 224, 3)``. Pixel values are passed through as-is: no scaling
    and no dtype conversion is performed here. According to the original
    author's notes, the model summary lists Rescaling and Normalization
    layers, so value scaling is left to the model itself.

    Args:
        image: The decoded input image.

    Returns:
        The resized RGB pixels with a batch dimension prepended.
    """
    rgb = image if image.mode == "RGB" else image.convert("RGB")
    pixels = np.array(rgb.resize((224, 224)))
    # Prepend the batch dimension.
    return pixels[np.newaxis, ...]
def read_root():
    """Return the API's welcome message as a JSON-serializable dict."""
    # NOTE(review): no route decorator is visible on this function in the
    # source as given -- confirm it is registered on `app`.
    greeting = "Welcome to the Fruit Classifier API"
    return {"message": greeting}
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and return the top prediction as JSON.

    Args:
        file: The uploaded image file.

    Returns:
        A dict with ``prediction`` (the class name from ``CLASS_NAMES``),
        ``confidence`` (the model's score for that class, as a float) and
        ``filename`` (the name of the uploaded file).

    Raises:
        HTTPException: 500 if the model is not loaded or if reading the
            upload / running inference fails; 400 if the upload cannot be
            decoded as an image.
    """
    # NOTE(review): no route decorator is visible on this function in the
    # source as given -- confirm it is registered on `app`.
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        contents = await file.read()

        # Decode in its own try so a bad upload is reported as a client
        # error (400) rather than being folded into the generic 500 below.
        try:
            image = Image.open(io.BytesIO(contents))
            # Image.open is lazy; force decoding here so truncated or
            # corrupt data fails inside this handler.
            image.load()
        except OSError as e:
            # Pillow's UnidentifiedImageError is a subclass of OSError.
            raise HTTPException(
                status_code=400, detail=f"Invalid image file: {e}"
            ) from e

        processed_image = preprocess_image(image)
        # model.predict is a synchronous call inside an async function.
        predictions = model.predict(processed_image)
        scores = predictions[0]
        # Coerce the numpy integer to a plain int before indexing the list.
        predicted_class_index = int(np.argmax(scores))
        confidence = float(scores[predicted_class_index])
        predicted_class = CLASS_NAMES[predicted_class_index]
    except HTTPException:
        # Keep the status code chosen above (HTTPException is an Exception).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "prediction": predicted_class,
        "confidence": confidence,
        "filename": file.filename,
    }
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image and return it with the result drawn on it.

    The predicted class name and confidence are written in the top-left
    corner of the (RGB-converted) original image, which is then returned
    JPEG-encoded.

    Args:
        file: The uploaded image file.

    Returns:
        A ``StreamingResponse`` with media type ``image/jpeg`` containing
        the annotated image.

    Raises:
        HTTPException: 500 if the model is not loaded or if reading,
            inference, drawing or encoding fails; 400 if the upload cannot
            be decoded as an image.
    """
    # NOTE(review): no route decorator is visible on this function in the
    # source as given -- confirm it is registered on `app`.
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        contents = await file.read()

        # Decode in its own try so a bad upload is reported as a client
        # error (400) rather than being folded into the generic 500 below.
        try:
            original_image = Image.open(io.BytesIO(contents))
            # Image.open is lazy; force decoding here so truncated or
            # corrupt data fails inside this handler.
            original_image.load()
            # Ensure RGB (also required for the JPEG encode below).
            if original_image.mode != "RGB":
                original_image = original_image.convert("RGB")
        except OSError as e:
            # Pillow's UnidentifiedImageError is a subclass of OSError.
            raise HTTPException(
                status_code=400, detail=f"Invalid image file: {e}"
            ) from e

        processed_image = preprocess_image(original_image)
        # model.predict is a synchronous call inside an async function.
        predictions = model.predict(processed_image)
        scores = predictions[0]
        # Coerce the numpy integer to a plain int before indexing the list.
        predicted_class_index = int(np.argmax(scores))
        confidence = float(scores[predicted_class_index])
        predicted_class = CLASS_NAMES[predicted_class_index]

        # Draw on the original image.
        draw = ImageDraw.Draw(original_image)

        # Scale the font with the image height, but never below 1: an image
        # shorter than 20 px would otherwise request a font of size 0.
        font_size = max(1, original_image.height // 20)
        try:
            # "arial.ttf" is typically only present on Windows systems.
            font = ImageFont.truetype("arial.ttf", size=font_size)
        except OSError:
            font = ImageFont.load_default()

        text = f"{predicted_class} ({confidence:.2f})"
        text_position = (10, 10)
        x, y = text_position
        outline_color = "black"
        text_color = "red"

        # Outline for visibility: draw the text offset by one pixel towards
        # each diagonal, then the coloured text on top.
        for dx, dy in ((-1, -1), (1, -1), (-1, 1), (1, 1)):
            draw.text((x + dx, y + dy), text, font=font, fill=outline_color)
        draw.text(text_position, text, font=font, fill=text_color)

        # Encode to an in-memory JPEG and rewind so the response reads it
        # from the start.
        img_byte_arr = io.BytesIO()
        original_image.save(img_byte_arr, format='JPEG')
        img_byte_arr.seek(0)
        return StreamingResponse(img_byte_arr, media_type="image/jpeg")
    except HTTPException:
        # Keep the status code chosen above (HTTPException is an Exception).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    # uvicorn is not imported at module level, so import it here; without
    # this, running the file directly fails with NameError. Keeping the
    # import local means importing this module does not require it.
    import uvicorn

    # Serve the app on all interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)