Thaqafini-API / main.py
sadekmarouf's picture
Update main.py
7701f17 verified
from fastapi import FastAPI, UploadFile, File
import joblib
import pandas as pd
import numpy as np
from pydantic import BaseModel
import os
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch
import io
app = FastAPI(title="Thaqafini API - Multi-Model Server")
# ==============================
# المتغيرات العالمية للموديلات
# ==============================
maternal_model = None
genetic_model = None
food_model = None
food_processor = None
# التعديل: استخدام النسخة المتخصصة في Food-101 (نفس الـ 101 صنف في الفايربيس)
FOOD_MODEL_CHECKPOINT = "nateraw/vit-base-food101"
# ==============================
# تحميل الموديلات عند تشغيل السيرفر
# ==============================
@app.on_event("startup")
async def load_models():
global maternal_model, genetic_model, food_model, food_processor
# 1. تحميل موديل صحة الأم
try:
if os.path.exists("random_forest_model.joblib"):
maternal_model = joblib.load("random_forest_model.joblib")
print("✅ Maternal model loaded successfully")
except Exception as e:
print(f"❌ Error loading Maternal model: {e}")
# 2. تحميل موديل الأمراض الوراثية
try:
if os.path.exists("thaqafni_model.pkl"):
genetic_model = joblib.load("thaqafni_model.pkl")
print("✅ Genetic model loaded successfully")
except Exception as e:
print(f"❌ Error loading Genetic model: {e}")
# 3. تحميل موديل الطعام المخصص (Food-101)
try:
print(f"🔄 Loading specialized Food-101 model ({FOOD_MODEL_CHECKPOINT})...")
food_processor = AutoImageProcessor.from_pretrained(FOOD_MODEL_CHECKPOINT)
food_model = AutoModelForImageClassification.from_pretrained(FOOD_MODEL_CHECKPOINT)
print("✅ Food-101 model (Specialized) loaded successfully")
except Exception as e:
print(f"❌ Error loading Food model: {e}")
# ==============================
# نماذج البيانات (Pydantic)
# ==============================
class MaternalInput(BaseModel):
age: int
systolic_bp: int
diastolic_bp: int
bs: float
body_temp: float
heart_rate: int
class GeneticInput(BaseModel):
age: int
family_history: int
hemoglobin: float
fetal_hemoglobin: float
sweat_chloride: float
sickled_rbc_percent: float
# ==============================
# المسارات (Endpoints)
# ==============================
@app.get("/")
def home():
return {
"status": "online",
"models_status": {
"maternal": "Ready" if maternal_model else "Not Loaded",
"genetic": "Ready" if genetic_model else "Not Loaded",
"food_101_specialized": "Ready" if food_model else "Not Loaded"
}
}
# 1. توقع مخاطر الأم
@app.post("/predict_maternal")
async def predict_maternal(data: MaternalInput):
if not maternal_model:
return {"error": "Maternal model is not available"}
features = np.array([[
data.age, data.systolic_bp, data.diastolic_bp,
data.bs, data.body_temp, data.heart_rate
]])
prediction = maternal_model.predict(features)
return {"risk_level": int(prediction[0])}
# 2. توقع الأمراض الوراثية
@app.post("/predict_genetic")
async def predict_genetic(data: GeneticInput):
if not genetic_model:
return {"error": "Genetic model is not available"}
input_data = pd.DataFrame([[
data.age, data.family_history, data.hemoglobin,
data.fetal_hemoglobin, data.sweat_chloride, data.sickled_rbc_percent
]], columns=['Age', 'Family_History', 'Hemoglobin', 'Fetal_Hemoglobin', 'Sweat_Chloride', 'Sickled_RBC_Percent'])
prediction = genetic_model.predict(input_data)[0]
probabilities = genetic_model.predict_proba(input_data)[0]
confidence = float(np.max(probabilities) * 100)
ar_map = {
"Thalassemia": "ثلاسيميا",
"Normal": "سليم - طبيعي",
"Sickle Cell Anemia": "فقر الدم المنجلي",
"Cystic Fibrosis": "تليف كيسي",
"High Risk": "معرض لخطورة عالية"
}
return {
"diagnosis": prediction,
"diagnosis_ar": ar_map.get(prediction, "غير معروف"),
"confidence": f"{confidence:.2f}%"
}
# 3. التعرف على الطعام (النسخة المتوافقة مع Firestore)
@app.post("/predict_food")
async def predict_food(file: UploadFile = File(...)):
if not food_model or not food_processor:
return {"error": "Food model is not available"}
try:
image_bytes = await file.read()
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
# المعالجة الخاصة بموديلات ViT
inputs = food_processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = food_model(**inputs)
probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
# استخراج أفضل 3 احتمالات
top_probs, top_indices = torch.topk(probs, 3)
predictions = []
for i in range(3):
label = food_model.config.id2label[top_indices[0][i].item()]
# ملاحظة: الموديل قد يخرج اسماً بمسافات، نحولها لـ _ لتطابق Firestore IDs
formatted_label = label.lower().replace(" ", "_")
predictions.append({
"label": formatted_label,
"confidence": f"{top_probs[0][i].item() * 100:.2f}%"
})
return {
"main_prediction": predictions[0]["label"],
"all_predictions": predictions,
"status": "success"
}
except Exception as e:
return {"error": str(e), "status": "failed"}