# sace / main.py
# bharathivijay: "Update main.py" (commit a1e43ab, verified)
import io
import logging
import os

import gdown
import numpy as np
from fastapi import FastAPI, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image, ImageOps, UnidentifiedImageError
from tensorflow.keras.models import load_model
app = FastAPI()

# CORS: allowed origins come from the comma-separated CORS_ALLOW_ORIGINS
# environment variable, so a production deployment can restrict them without
# a code change. When the variable is unset the permissive "*" is kept; when it
# is set to an empty string no cross-origin access is allowed.
# NOTE(review): "*" combined with allow_credentials=True is very permissive
# (browsers reject a literal "*" for credentialed requests, so the middleware
# may echo the caller's Origin instead) — confirm this is intended.
CORS_ALLOW_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# --- Model location -------------------------------------------------------
# The trained weights are hosted on Google Drive and cached in a local,
# writable scratch directory.
MODEL_FILENAME = "face_shape_model_optimized.keras"
GDRIVE_FILE_ID = "18Fux2G1e8uuKFD5coZGj4T5OXKn26DcK"
GDRIVE_URL = "https://drive.google.com/uc?id=" + GDRIVE_FILE_ID
# POSIX temp directory; holds the downloaded model and gdown's cache.
TMP_DIR = "/tmp"
def download_model():
    """Ensure the model file exists in TMP_DIR and return its local path.

    The file is fetched from GDRIVE_URL with gdown only when it is not already
    on disk, so a process restart that keeps TMP_DIR reuses the earlier copy.

    Returns:
        The path ``TMP_DIR/MODEL_FILENAME``.

    Raises:
        RuntimeError: if the download step finished without creating the file.
        Exceptions raised by gdown itself propagate unchanged.
    """
    # Redirect gdown's cache to a writable path.
    # NOTE(review): confirm which of these variables the installed gdown honours.
    os.environ["GDOWN_CACHE_DIR"] = TMP_DIR
    os.environ["XDG_CACHE_HOME"] = TMP_DIR

    output_path = os.path.join(TMP_DIR, MODEL_FILENAME)
    if os.path.exists(output_path):
        print(f"Using cached model at {output_path}")
        return output_path

    print(f"Downloading model from Google Drive: {GDRIVE_URL}")
    gdown.download(GDRIVE_URL, output_path, quiet=False)
    # gdown may report some failures (e.g. permission/quota problems) without
    # raising, depending on version; fail loudly here instead of letting
    # load_model trip over a missing file.
    if not os.path.exists(output_path):
        raise RuntimeError(
            f"Model download from {GDRIVE_URL} did not create {output_path}"
        )
    return output_path
# Maps the model's output index to a face-shape name.
# NOTE(review): the order presumably matches the training label order — TODO confirm.
class_labels = "diamond heart oval round square".split()

# Fetch the weights (reused if already on disk) and load the Keras model once,
# at import time, so every request shares the same instance.
model_path = download_model()
model = load_model(model_path)
def preprocess_image(image: Image.Image) -> np.ndarray:
    """Turn a decoded PIL image into a single-item batch for the model.

    The image is first oriented according to its EXIF orientation tag, then
    converted to RGB, resized to 224x224 and scaled from 0-255 to [0, 1].

    Args:
        image: A decoded PIL image in any mode.

    Returns:
        A float32 array of shape (1, 224, 224, 3).
    """
    # Camera/phone JPEGs often store pixels rotated and rely on the EXIF
    # orientation tag for display; without this the model sees a sideways face.
    image = ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    image = image.resize((224, 224))
    image_array = np.asarray(image, dtype=np.float32) / 255.0
    # Add the leading batch dimension expected by model.predict.
    return np.expand_dims(image_array, axis=0)
def _decode_image(contents: bytes) -> np.ndarray:
    """Decode uploaded image bytes into a model-ready batch.

    Raises:
        UnidentifiedImageError, OSError: if the bytes are not a readable image
            (unknown format, truncated data, ...).
    """
    with Image.open(io.BytesIO(contents)) as image:
        # preprocess_image converts the image, which forces a full decode, so
        # truncated files fail here rather than later.
        return preprocess_image(image)


def _predict_label(batch: np.ndarray) -> str:
    """Run the model on *batch* and return the label of the top-scoring class."""
    # verbose=0 stops Keras from printing a progress bar on every request.
    predictions = model.predict(batch, verbose=0)
    return class_labels[int(np.argmax(predictions))]


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify the face shape in an uploaded PNG or JPEG image.

    Returns ``{"result": "Detected face shape: <label>"}`` on success. Returns
    a 400 JSON error for a disallowed content type or an undecodable image,
    and a 500 JSON error for any other failure.
    """
    if file.content_type not in ("image/png", "image/jpeg"):
        return JSONResponse(status_code=400, content={"error": "Only PNG or JPG images are allowed"})
    try:
        contents = await file.read()
        # Decoding and inference are CPU-bound. Running them in the threadpool
        # keeps the event loop (and /health) responsive during a prediction.
        # NOTE(review): concurrent requests now call model.predict from several
        # worker threads — confirm this is safe for the deployed Keras/TF version.
        try:
            processed_image = await run_in_threadpool(_decode_image, contents)
        except (UnidentifiedImageError, OSError):
            # Corrupt or truncated uploads are client errors, not server errors.
            return JSONResponse(status_code=400, content={"error": "Uploaded file is not a valid image."})
        predicted_class = await run_in_threadpool(_predict_label, processed_image)
        return {"result": f"Detected face shape: {predicted_class}"}
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log
        # instead of silently discarding it.
        logging.getLogger(__name__).exception("Prediction failed")
        return JSONResponse(status_code=500, content={"error": f"An error occurred: {str(e)}"})
@app.get("/")
async def root():
    """Landing endpoint that reports the API is running."""
    # async like /health: nothing here blocks, so there is no need for the
    # threadpool hop FastAPI uses for plain `def` endpoints.
    return {"message": "Face Shape API is live!"}
@app.get("/health")
async def health():
    """Health-check endpoint; always answers with status "ok"."""
    payload = dict(status="ok")
    return payload