# Sifrac-ML / app.py
# Last update: "Update app.py" by ncardian (commit 3ba665f, verified)
import os
from typing import Optional
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.responses import JSONResponse
import numpy as np
from PIL import Image
from io import BytesIO
from huggingface_hub import login, hf_hub_download
import keras
import scipy.io
import scipy.ndimage
# --- Application setup -------------------------------------------------------

# The FastAPI application; title and description appear in the generated API docs.
app = FastAPI(
    title="PETRA API",
    description="API PETRA, Input Shape (200, 200, 3)",
)

# Authenticate with the Hugging Face Hub only when a token is supplied through
# the HF_TOKENS environment variable.
if "HF_TOKENS" in os.environ:
    login(token=os.environ["HF_TOKENS"])

# NOTE(review): `keras` is imported at the top of this file, before this line
# runs. Keras picks its backend from KERAS_BACKEND when it is first imported,
# so this assignment most likely does not change the backend of the already
# imported module; move it above `import keras` if it is meant to take effect.
os.environ["KERAS_BACKEND"] = "tensorflow"

# Shape of one model input sample: a 200 x 200 image with 3 colour channels.
INPUT_SHAPE = (200, 200, 3)
# Load the trained Keras model from the Hugging Face Hub at import time.
# A failure here is fatal: the API cannot serve predictions without a model,
# so the error is re-raised (with its original cause attached) to abort startup.
try:
    # Fetch the serialized model file from the Hub repository.
    model_path = hf_hub_download(repo_id="drprs-research/sifrac-ml", filename="sifract.keras")
    model = keras.models.load_model(model_path)
except Exception as e:
    # `from e` keeps the underlying download/deserialization traceback visible.
    raise RuntimeError(f"Error loading model: {e}") from e
def preprocess_image(image: BytesIO) -> tuple[np.ndarray, int]:
    """Turn an uploaded MATLAB ``.mat`` volume into a model-ready RGB batch.

    Despite its name, this function does not receive a picture. It reads a
    ``.mat`` file and takes its first non-metadata variable as a 3-D volume.
    It resamples that volume to 200 x 200 x 200 with nearest-neighbour
    interpolation. Finally, it packs two orthogonal mid-volume cross-sections
    into the green and blue channels of a 200 x 200 RGB image; the red
    channel is all zeros.

    Args:
        image: A file-like object (or path) holding a ``.mat`` file, in any
            form accepted by :func:`scipy.io.loadmat`.

    Returns:
        A tuple ``(batch, depth)``:

        - ``batch`` is a float32 array of shape ``(1, 200, 200, 3)`` with
          values in ``[0, 1]``.
        - ``depth`` is the size of the *original* (pre-resampling) volume
          along its third axis.

    Raises:
        ValueError: If the file holds no data variable, or the variable is not
            a non-empty 3-D array.
    """
    target = 200  # edge length of the resampled cube (matches the 200 x 200 model input)

    mat_data = scipy.io.loadmat(image)
    # loadmat also returns metadata entries such as '__header__'; skip them.
    data_keys = [key for key in mat_data if not key.startswith("__")]
    if not data_keys:
        raise ValueError("The .mat file does not contain any data variable.")

    volume = mat_data[data_keys[0]]
    if volume.ndim != 3 or 0 in volume.shape:
        raise ValueError(
            f"Expected a non-empty 3-D volume in the .mat file, got shape {volume.shape}."
        )
    original_depth = int(volume.shape[2])

    # Nearest-neighbour (order=0) resampling to a target x target x target cube.
    zoom_factors = [target / dim for dim in volume.shape]
    resampled = scipy.ndimage.zoom(volume, zoom_factors, order=0)
    x_size, y_size, _ = resampled.shape

    # Mid-volume cross-sections, scaled by 255 before the uint8 conversion below.
    # YZ plane at X = x_size // 2, transposed to (Z, Y).
    yz_plane = np.transpose(resampled[x_size // 2, :, :] * 255)
    # XZ plane at Y = y_size // 2, shape (X, Z).
    xz_plane = resampled[:, y_size // 2, :] * 255

    # Channel 0 stays all zeros. NOTE(review): presumably this mirrors the
    # preprocessing used when the model was trained — confirm before changing.
    rgb = np.zeros((x_size, y_size, 3), dtype=np.uint8)
    rgb[:, :, 1] = yz_plane.astype(np.uint8)
    rgb[:, :, 2] = xz_plane.astype(np.uint8)

    # Scale to [0, 1] and add the leading batch axis expected by model.predict().
    batch = np.expand_dims(rgb.astype("float32") / 255.0, axis=0)
    return batch, original_depth
@app.get("/")
async def root():
    """Describe the service: its name and the shape of one model input sample."""
    payload = {"message": "SIFRACT-ML API", "input_shape": INPUT_SHAPE}
    return payload
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Run the model on an uploaded MATLAB ``.mat`` volume.

    The upload is read fully into memory and converted by ``preprocess_image``
    into a model input batch. That batch is passed to ``model.predict``. The
    first output value of the first sample is then rescaled from the
    200-voxel resampled grid to the original volume's third-axis size.

    Returns:
        A JSON response ``{"prediction": ...}`` holding the model output as
        nested lists.

    Raises:
        HTTPException: 400 if the upload cannot be read or preprocessed;
            500 if inference or post-processing fails.
    """
    # NOTE(review): preprocessing and model.predict() are CPU-bound and run
    # directly on the event loop in this async handler, blocking other requests
    # while they execute; consider offloading them to a worker thread.
    try:
        contents = await file.read()
        processed_image, original_depth = preprocess_image(BytesIO(contents))
    except Exception as e:
        # Failures while parsing the upload are caused by the client's input,
        # so report them as a bad request rather than a server error.
        raise HTTPException(status_code=400, detail=f"Error processing image: {e}") from e

    try:
        # tolist() converts the numpy output into JSON-serializable lists.
        prediction = model.predict(processed_image).tolist()
        # The input volume was resampled to 200 along each axis; map the first
        # output value back onto the original third-axis size.
        prediction[0][0] = (prediction[0][0] / 200) * original_depth
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing image: {e}") from e

    return JSONResponse(content={"prediction": prediction})