Spaces:
Sleeping
Sleeping
File size: 5,282 Bytes
31082a9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 | # import base64 # Not needed
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import tensorflow as tf
import io
from typing import List
app = FastAPI(title="Fruit Classifier API")

# Relative path: resolved against the process's current working directory.
MODEL_PATH = "fruit_classifier_model.h5"

# Load the Keras model once at import time. A failure is reported but does
# not stop the app from starting; the prediction endpoints check for
# ``model is None`` and reject requests in that case.
try:
    model = tf.keras.models.load_model(MODEL_PATH)
except Exception as load_error:
    print(f"Error loading model: {load_error}")
    model = None
else:
    print(f"Model loaded successfully from {MODEL_PATH}")
# Class labels copied from the training notebook. Position i is the label
# reported for output index i of the model (the endpoints index this list with
# the argmax of the prediction), so the order must match training exactly.
# NOTE(review): the list is alphabetical, which presumably mirrors the sorted
# class-directory order used when the training dataset was built -- confirm.
CLASS_NAMES = [
    'Apple', 'Apricots', 'Avocado',
    'Banana', 'Blackberries', 'Blueberry',
    'Cantaloupe', 'Cherry', 'Coconut',
    'Dates', 'Dragon fruit',
    'Fig',
    'Grapes', 'Guava',
    'Jackfruit',
    'Kiwi',
    'Lemons', 'Lychee',
    'Mango',
    'Olive', 'Orange',
    'Papaya', 'Pear', 'Persimmon', 'Pineapple', 'Plum', 'Pomegranate',
    'Rambutan', 'Raspberry',
    'Salak', 'Sapodilla', 'Soursop', 'Starfruit', 'Strawberry',
    'Watermelon',
]
def preprocess_image(image: Image.Image, target_size=(224, 224)) -> np.ndarray:
    """Convert a PIL image into a single-image batch for the classifier.

    The training notebook loaded data with
    ``tf.keras.utils.image_dataset_from_directory(..., image_size=(224, 224))``
    and the model summary showed ``Rescaling`` and ``Normalization`` layers,
    so pixel scaling happens inside the model. This function therefore only
    converts to RGB, resizes and casts; it deliberately does NOT rescale the
    pixel values.

    Args:
        image: Input image in any PIL mode; non-RGB modes are converted.
        target_size: ``(width, height)`` passed to ``Image.resize``. Defaults
            to the 224x224 input size used in training.

    Returns:
        A ``float32`` array of shape ``(1, height, width, 3)`` holding raw
        pixel values in ``[0, 255]``.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    resized = image.resize(tuple(target_size))
    # Cast explicitly to float32 (the dtype of image_dataset_from_directory
    # batches) instead of handing the model a uint8 array.
    pixels = np.asarray(resized, dtype=np.float32)
    return pixels[np.newaxis, ...]  # add the batch dimension
@app.get("/")
def read_root():
    """Return a static welcome message for the API root."""
    greeting = "Welcome to the Fruit Classifier API"
    return {"message": greeting}
def _load_upload_image(contents: bytes) -> Image.Image:
    """Decode uploaded bytes into a fully loaded PIL image.

    Raises:
        HTTPException: 400 when the payload is not a decodable image, so a
            client mistake is not reported as a server error (500).
    """
    try:
        image = Image.open(io.BytesIO(contents))
        # Image.open is lazy (it only identifies the file); load() forces the
        # full decode so truncated or corrupt files are rejected here too.
        image.load()
    except (OSError, Image.DecompressionBombError) as exc:
        raise HTTPException(
            status_code=400, detail=f"Invalid image file: {exc}"
        ) from exc
    return image


def _classify(image: Image.Image) -> tuple:
    """Run the model on *image* and return ``(class_name, confidence)``.

    ``confidence`` is the model's output value for the winning class, as a
    Python float.
    """
    batch = preprocess_image(image)
    scores = model.predict(batch)[0]  # first (only) item of the batch
    index = int(np.argmax(scores))
    return CLASS_NAMES[index], float(scores[index])


def _draw_label(image: Image.Image, text: str) -> None:
    """Draw *text* in red with a 1-px black outline at the top-left of *image*."""
    draw = ImageDraw.Draw(image)
    # Scale the font with the image height. Clamp to >= 1 because a font size
    # of 0 (any image shorter than 20 px) makes truetype() raise.
    font_size = max(1, image.height // 20)
    try:
        # arial.ttf is typically only present on Windows; fall back otherwise.
        font = ImageFont.truetype("arial.ttf", size=font_size)
    except (OSError, ValueError):
        font = ImageFont.load_default()
    x, y = 10, 10
    # Draw the text offset diagonally in black first to form an outline.
    for dx, dy in ((-1, -1), (1, -1), (-1, 1), (1, 1)):
        draw.text((x + dx, y + dy), text, font=font, fill="black")
    draw.text((x, y), text, font=font, fill="red")


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Returns:
        ``{"prediction": <class name>, "confidence": <float>, "filename": ...}``

    Raises:
        HTTPException: 500 if the model is not loaded or inference fails;
            400 if the upload is not a decodable image.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    contents = await file.read()
    image = _load_upload_image(contents)
    try:
        predicted_class, confidence = _classify(image)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {
        "prediction": predicted_class,
        "confidence": confidence,
        "filename": file.filename,
    }


@app.post("/predict_image")
async def predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image and return it annotated with the label.

    The response is a JPEG of the uploaded image with
    ``"<class> (<confidence>)"`` drawn in its top-left corner.

    Raises:
        HTTPException: 500 if the model is not loaded or processing fails;
            400 if the upload is not a decodable image.
    """
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    contents = await file.read()
    original_image = _load_upload_image(contents)
    try:
        # Convert once up front: the annotated image is saved as JPEG, which
        # cannot store alpha or palette modes.
        if original_image.mode != "RGB":
            original_image = original_image.convert("RGB")
        predicted_class, confidence = _classify(original_image)
        _draw_label(original_image, f"{predicted_class} ({confidence:.2f})")
        img_byte_arr = io.BytesIO()
        original_image.save(img_byte_arr, format='JPEG')
        img_byte_arr.seek(0)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return StreamingResponse(img_byte_arr, media_type="image/jpeg")
if __name__ == "__main__":
    # uvicorn was called here but never imported, so running this file
    # directly raised NameError. Import it locally: it is only needed when
    # the module is executed as a script, not when a server imports `app`.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
|