File size: 3,191 Bytes
5e7f154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from io import BytesIO
from PIL import Image
import numpy as np
from typing import List
import logging

# Module-level wiring: logging, the FastAPI app with permissive CORS, and the
# classification model loaded once at import time.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers under the CORS spec — tighten origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model is loaded eagerly so the service fails fast at startup if the file
# is missing or incompatible, rather than on the first request.
MODEL_PATH = "model.keras"
try:
    logger.info(f"Loading model from {MODEL_PATH}")
    model = tf.keras.models.load_model(MODEL_PATH)
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise

# Index order must match the label order the model was trained with.
CLASS_NAMES = ['healthy', 'rotten']

def process_image(image: Image.Image) -> np.ndarray:
    """Convert a PIL image into a normalized array ready for the model.

    The image is forced to RGB, resized to the model's expected 128x128
    input, and scaled from [0, 255] into [0.0, 1.0].

    Args:
        image: Source image in any PIL-supported mode.

    Returns:
        np.ndarray of shape (128, 128, 3), dtype float32, values in [0, 1].

    Raises:
        ValueError: If the image cannot be converted or resized.
    """
    try:
        # Converting to RGB first guarantees exactly three channels, so the
        # original post-hoc alpha-channel stripping (shape[-1] == 4) was
        # dead code and has been removed.
        if image.mode != 'RGB':
            image = image.convert('RGB')
        image = image.resize((128, 128))
        # float32 matches Keras' default input dtype and halves memory
        # versus the float64 that plain np.array(...) / 255.0 produced.
        return np.asarray(image, dtype=np.float32) / 255.0
    except Exception as e:
        logger.error(f"Error processing image: {e}")
        # Chain the original exception so the root cause stays visible.
        raise ValueError(f"Error processing image: {e}") from e

@app.post("/predict")
async def predict(files: List[UploadFile] = File(...)):
    """Classify each uploaded image as 'healthy' or 'rotten'.

    Returns JSON {"predictions": [...], "status": "success"} on success.
    Oversized or missing files raise HTTPException (rendered by FastAPI
    with its own status code); unexpected processing errors return HTTP 400
    with {"predictions": [], "status": "error", "error": ...}.
    """
    try:
        if not files:
            logger.warning("No files provided in request")
            raise HTTPException(status_code=400, detail="No files provided")

        results = []
        for file in files:
            logger.info(f"Processing file: {file.filename}")

            # UploadFile.size can be None (e.g. chunked transfer encoding);
            # the original unguarded comparison raised TypeError in that case.
            if file.size is not None and file.size > 10 * 1024 * 1024:
                logger.warning(f"File {file.filename} too large: {file.size} bytes")
                raise HTTPException(status_code=400, detail=f"File {file.filename} too large")

            content = await file.read()
            image = Image.open(BytesIO(content))
            image_array = process_image(image)
            # Model expects a batch dimension: (1, H, W, C).
            image_array = np.expand_dims(image_array, axis=0)

            logger.info(f"Making prediction for {file.filename}")
            predictions = model.predict(image_array, verbose=0)
            predicted_index = np.argmax(predictions[0])
            predicted_class = CLASS_NAMES[predicted_index]
            confidence = float(np.max(predictions[0]))

            results.append({
                "filename": file.filename or f"image_{len(results)}.jpg",
                "class": predicted_class,
                "confidence": round(confidence, 4)
            })

        logger.info(f"Returning {len(results)} predictions")
        return JSONResponse({
            "predictions": results,
            "status": "success"
        })

    except HTTPException:
        # Re-raise intentional HTTP errors so FastAPI renders them with
        # their own status/detail instead of the generic handler below
        # flattening them into a plain 400 error body.
        raise
    except Exception as e:
        logger.error(f"Error in predict endpoint: {e}")
        return JSONResponse(
            status_code=400,
            content={
                "predictions": [],
                "status": "error",
                "error": str(e)
            }
        )