File size: 6,059 Bytes
608bcbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65e94d4
7701f17
608bcbb
 
 
65e94d4
608bcbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65e94d4
608bcbb
65e94d4
608bcbb
 
65e94d4
608bcbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65e94d4
608bcbb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65e94d4
608bcbb
 
 
 
 
 
 
 
 
65e94d4
608bcbb
 
 
 
 
 
65e94d4
 
608bcbb
65e94d4
608bcbb
 
65e94d4
 
 
 
608bcbb
65e94d4
608bcbb
 
 
 
 
 
 
 
 
65e94d4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
from fastapi import FastAPI, UploadFile, File
import joblib
import pandas as pd
import numpy as np
from pydantic import BaseModel
import os
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch
import io

app = FastAPI(title="Thaqafini API - Multi-Model Server")

# ==============================
# المتغيرات العالمية للموديلات
# ==============================
maternal_model = None
genetic_model = None
food_model = None
food_processor = None

# التعديل: استخدام النسخة المتخصصة في Food-101 (نفس الـ 101 صنف في الفايربيس)
FOOD_MODEL_CHECKPOINT = "nateraw/vit-base-food101"

# ==============================
# تحميل الموديلات عند تشغيل السيرفر
# ==============================
@app.on_event("startup")
async def load_models():
    global maternal_model, genetic_model, food_model, food_processor
    
    # 1. تحميل موديل صحة الأم
    try:
        if os.path.exists("random_forest_model.joblib"):
            maternal_model = joblib.load("random_forest_model.joblib")
            print("✅ Maternal model loaded successfully")
    except Exception as e:
        print(f"❌ Error loading Maternal model: {e}")

    # 2. تحميل موديل الأمراض الوراثية
    try:
        if os.path.exists("thaqafni_model.pkl"):
            genetic_model = joblib.load("thaqafni_model.pkl")
            print("✅ Genetic model loaded successfully")
    except Exception as e:
        print(f"❌ Error loading Genetic model: {e}")

    # 3. تحميل موديل الطعام المخصص (Food-101)
    try:
        print(f"🔄 Loading specialized Food-101 model ({FOOD_MODEL_CHECKPOINT})...")
        food_processor = AutoImageProcessor.from_pretrained(FOOD_MODEL_CHECKPOINT)
        food_model = AutoModelForImageClassification.from_pretrained(FOOD_MODEL_CHECKPOINT)
        print("✅ Food-101 model (Specialized) loaded successfully")
    except Exception as e:
        print(f"❌ Error loading Food model: {e}")

# ==============================
# نماذج البيانات (Pydantic)
# ==============================
class MaternalInput(BaseModel):
    age: int
    systolic_bp: int
    diastolic_bp: int
    bs: float
    body_temp: float
    heart_rate: int

class GeneticInput(BaseModel):
    age: int
    family_history: int
    hemoglobin: float
    fetal_hemoglobin: float
    sweat_chloride: float
    sickled_rbc_percent: float

# ==============================
# المسارات (Endpoints)
# ==============================

@app.get("/")
def home():
    return {
        "status": "online",
        "models_status": {
            "maternal": "Ready" if maternal_model else "Not Loaded",
            "genetic": "Ready" if genetic_model else "Not Loaded",
            "food_101_specialized": "Ready" if food_model else "Not Loaded"
        }
    }

# 1. توقع مخاطر الأم
@app.post("/predict_maternal")
async def predict_maternal(data: MaternalInput):
    if not maternal_model:
        return {"error": "Maternal model is not available"}
    
    features = np.array([[
        data.age, data.systolic_bp, data.diastolic_bp, 
        data.bs, data.body_temp, data.heart_rate
    ]])
    prediction = maternal_model.predict(features)
    return {"risk_level": int(prediction[0])}

# 2. توقع الأمراض الوراثية
@app.post("/predict_genetic")
async def predict_genetic(data: GeneticInput):
    if not genetic_model:
        return {"error": "Genetic model is not available"}
    
    input_data = pd.DataFrame([[
        data.age, data.family_history, data.hemoglobin,
        data.fetal_hemoglobin, data.sweat_chloride, data.sickled_rbc_percent
    ]], columns=['Age', 'Family_History', 'Hemoglobin', 'Fetal_Hemoglobin', 'Sweat_Chloride', 'Sickled_RBC_Percent'])
    
    prediction = genetic_model.predict(input_data)[0]
    probabilities = genetic_model.predict_proba(input_data)[0]
    confidence = float(np.max(probabilities) * 100)
    
    ar_map = {
        "Thalassemia": "ثلاسيميا",
        "Normal": "سليم - طبيعي",
        "Sickle Cell Anemia": "فقر الدم المنجلي",
        "Cystic Fibrosis": "تليف كيسي",
        "High Risk": "معرض لخطورة عالية"
    }
    
    return {
        "diagnosis": prediction,
        "diagnosis_ar": ar_map.get(prediction, "غير معروف"),
        "confidence": f"{confidence:.2f}%"
    }

# 3. التعرف على الطعام (النسخة المتوافقة مع Firestore)
@app.post("/predict_food")
async def predict_food(file: UploadFile = File(...)):
    if not food_model or not food_processor:
        return {"error": "Food model is not available"}
    
    try:
        image_bytes = await file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        
        # المعالجة الخاصة بموديلات ViT
        inputs = food_processor(images=image, return_tensors="pt")
        
        with torch.no_grad():
            outputs = food_model(**inputs)
        
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        
        # استخراج أفضل 3 احتمالات
        top_probs, top_indices = torch.topk(probs, 3)
        
        predictions = []
        for i in range(3):
            label = food_model.config.id2label[top_indices[0][i].item()]
            # ملاحظة: الموديل قد يخرج اسماً بمسافات، نحولها لـ _ لتطابق Firestore IDs
            formatted_label = label.lower().replace(" ", "_")
            
            predictions.append({
                "label": formatted_label,
                "confidence": f"{top_probs[0][i].item() * 100:.2f}%"
            })

        return {
            "main_prediction": predictions[0]["label"],
            "all_predictions": predictions,
            "status": "success"
        }
    except Exception as e:
        return {"error": str(e), "status": "failed"}