File size: 5,282 Bytes
31082a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
# import base64 # Not needed
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import StreamingResponse
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import tensorflow as tf
import io
from typing import List

# FastAPI application instance; endpoints below are registered against it.
app = FastAPI(title="Fruit Classifier API")

# Load the model
# Assuming the model is in the same directory
MODEL_PATH: str = "fruit_classifier_model.h5"
try:
    # Keras model saved in HDF5 format; loaded once at import time so every
    # request reuses the same in-memory model.
    model = tf.keras.models.load_model(MODEL_PATH)
    print(f"Model loaded successfully from {MODEL_PATH}")
except Exception as e:
    print(f"Error loading model: {e}")
    # We allow the app to start even if model fails, but predict will fail
    model = None

# Class names extracted from the training notebook
# Order matters: index i corresponds to the model's output unit i
# (argmax over the prediction vector indexes into this list).
CLASS_NAMES: List[str] = [
    'Apple', 'Apricots', 'Avocado', 'Banana', 'Blackberries', 'Blueberry', 
    'Cantaloupe', 'Cherry', 'Coconut', 'Dates', 'Dragon fruit', 'Fig', 
    'Grapes', 'Guava', 'Jackfruit', 'Kiwi', 'Lemons', 'Lychee', 'Mango', 
    'Olive', 'Orange', 'Papaya', 'Pear', 'Persimmon', 'Pineapple', 'Plum', 
    'Pomegranate', 'Rambutan', 'Raspberry', 'Salak', 'Sapodilla', 'Soursop', 
    'Starfruit', 'Strawberry', 'Watermelon'
]

def preprocess_image(image: Image.Image, target_size: tuple = (224, 224)) -> np.ndarray:
    """Convert a PIL image into a model-ready batch array.

    The training pipeline used image_size=(224, 224), and the model contains
    its own Rescaling/Normalization layers, so no manual /255 scaling is done
    here — we only resize, cast, and add the batch axis.

    Args:
        image: Input image. Converted to RGB if it arrives in another mode
            (e.g. RGBA, palette, grayscale).
        target_size: (width, height) expected by the model. Defaults to
            (224, 224) to match the training configuration.

    Returns:
        A float32 array of shape (1, height, width, 3).
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    resized = image.resize(target_size)
    # Cast explicitly to float32: the original docstring noted this is the
    # safer input dtype, but the code previously passed raw uint8 pixels and
    # relied on the framework's implicit conversion.
    array = np.asarray(resized, dtype=np.float32)
    return np.expand_dims(array, axis=0)  # add batch dimension

@app.get("/")
def read_root():
    return {"message": "Welcome to the Fruit Classifier API"}

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    
    try:
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
        processed_image = preprocess_image(image)
        
        predictions = model.predict(processed_image)
        predicted_class_index = np.argmax(predictions[0])
        confidence = float(predictions[0][predicted_class_index])
        predicted_class = CLASS_NAMES[predicted_class_index]
        
        return {
            "prediction": predicted_class,
            "confidence": confidence,
            "filename": file.filename
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/predict_image")
async def predict_image(file: UploadFile = File(...)):
    if model is None:
        raise HTTPException(status_code=500, detail="Model not loaded")
    
    try:
        contents = await file.read()
        original_image = Image.open(io.BytesIO(contents))
        
        # Ensure RGB
        if original_image.mode != "RGB":
            original_image = original_image.convert("RGB")
            
        processed_image = preprocess_image(original_image)
        
        predictions = model.predict(processed_image)
        predicted_class_index = np.argmax(predictions[0])
        confidence = float(predictions[0][predicted_class_index])
        predicted_class = CLASS_NAMES[predicted_class_index]
        
        # Draw on the original image
        draw = ImageDraw.Draw(original_image)
        
        # Try to load a nice font, otherwise default
        try:
            # Try loading a system font (Windows usually has arial)
            font = ImageFont.truetype("arial.ttf", size=int(original_image.height / 20))
        except IOError:
            font = ImageFont.load_default()
            
        text = f"{predicted_class} ({confidence:.2f})"
        
        # Calculate text position (top-left or centered-top)
        text_position = (10, 10)
        
        # Draw text with outline for better visibility
        x, y = text_position
        outline_color = "black"
        text_color = "red"
        
        draw.text((x-1, y-1), text, font=font, fill=outline_color)
        draw.text((x+1, y-1), text, font=font, fill=outline_color)
        draw.text((x-1, y+1), text, font=font, fill=outline_color)
        draw.text((x+1, y+1), text, font=font, fill=outline_color)
        draw.text(text_position, text, font=font, fill=text_color)
        
        # Save to bytes
        img_byte_arr = io.BytesIO()
        original_image.save(img_byte_arr, format='JPEG')
        img_byte_arr.seek(0)
        
        return StreamingResponse(img_byte_arr, media_type="image/jpeg")
        
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)