sadekmarouf commited on
Commit
608bcbb
·
verified ·
1 Parent(s): 1c05cec

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +170 -207
main.py CHANGED
@@ -1,207 +1,170 @@
1
- from fastapi import FastAPI, UploadFile, File
2
- import joblib
3
- import pandas as pd
4
- import numpy as np
5
- from pydantic import BaseModel
6
- import os
7
- from transformers import AutoImageProcessor, AutoModelForImageClassification
8
- from PIL import Image
9
- import torch
10
- import io
11
-
12
- app = FastAPI()
13
-
14
- # ==============================
15
- # المتغيرات العالمية للمودلات
16
- # ==============================
17
-
18
- maternal_model = None
19
- genetic_model = None
20
- food_model = None
21
- food_processor = None
22
-
23
- # ==============================
24
- # تحميل المودلات عند تشغيل السيرفر
25
- # ==============================
26
-
27
- @app.on_event("startup")
28
- def load_models():
29
- global maternal_model, genetic_model, food_model, food_processor
30
-
31
- # تحميل موديل الأم
32
- try:
33
- if os.path.exists("random_forest_model.joblib"):
34
- maternal_model = joblib.load("random_forest_model.joblib")
35
- print("✅ Maternal model loaded successfully")
36
- else:
37
- print("❌ File 'random_forest_model.joblib' NOT found!")
38
- except Exception as e:
39
- print(f"❌ Error loading Maternal model: {e}")
40
-
41
- # تحميل موديل الوراثة
42
- try:
43
- model_name = "thaqafni_model.pkl"
44
- if os.path.exists(model_name):
45
- genetic_model = joblib.load(model_name)
46
- print(f"✅ Genetic model '{model_name}' loaded successfully")
47
- else:
48
- print(f" File '{model_name}' NOT found!")
49
- except Exception as e:
50
- print(f"❌ Error loading Genetic model: {e}")
51
-
52
- # تحميل مودل التعرف على الطعام
53
- try:
54
- food_path = "food101"
55
-
56
- if os.path.exists(food_path):
57
- food_processor = AutoImageProcessor.from_pretrained(food_path)
58
- food_model = AutoModelForImageClassification.from_pretrained(food_path)
59
-
60
- print("✅ Food model loaded successfully")
61
-
62
- else:
63
- print("❌ Folder 'food101' NOT found!")
64
-
65
- except Exception as e:
66
- print(f"❌ Error loading Food model: {e}")
67
-
68
- # ==============================
69
- # نماذج البيانات
70
- # ==============================
71
-
72
- class MaternalInput(BaseModel):
73
- age: int
74
- systolic_bp: int
75
- diastolic_bp: int
76
- bs: float
77
- body_temp: float
78
- heart_rate: int
79
-
80
- class GeneticInput(BaseModel):
81
- age: int
82
- family_history: int
83
- hemoglobin: float
84
- fetal_hemoglobin: float
85
- sweat_chloride: float
86
- sickled_rbc_percent: float
87
-
88
- # ==============================
89
- # الصفحة الرئيسية
90
- # ==============================
91
-
92
- @app.get("/")
93
- def home():
94
- return {
95
- "status": "online",
96
- "maternal_model": "Ready" if maternal_model else "Not Loaded",
97
- "genetic_model": "Ready" if genetic_model else "Not Loaded",
98
- "food_model": "Ready" if food_model else "Not Loaded"
99
- }
100
-
101
- # ==============================
102
- # مودل مخاطر الأم
103
- # ==============================
104
-
105
- @app.post("/predict_maternal")
106
- async def predict_maternal(data: MaternalInput):
107
-
108
- if not maternal_model:
109
- return {"error": "Maternal model is not available"}
110
-
111
- features = np.array([[
112
- data.age,
113
- data.systolic_bp,
114
- data.diastolic_bp,
115
- data.bs,
116
- data.body_temp,
117
- data.heart_rate
118
- ]])
119
-
120
- prediction = maternal_model.predict(features)
121
-
122
- return {
123
- "risk_level": int(prediction[0])
124
- }
125
-
126
- # ==============================
127
- # مودل الأمراض الوراثية
128
- # ==============================
129
-
130
- @app.post("/predict_genetic")
131
- async def predict_genetic(data: GeneticInput):
132
-
133
- if not genetic_model:
134
- return {"error": "Genetic model is not available"}
135
-
136
- input_data = pd.DataFrame([[
137
- data.age,
138
- data.family_history,
139
- data.hemoglobin,
140
- data.fetal_hemoglobin,
141
- data.sweat_chloride,
142
- data.sickled_rbc_percent
143
- ]],
144
- columns=[
145
- 'Age',
146
- 'Family_History',
147
- 'Hemoglobin',
148
- 'Fetal_Hemoglobin',
149
- 'Sweat_Chloride',
150
- 'Sickled_RBC_Percent'
151
- ])
152
-
153
- prediction = genetic_model.predict(input_data)[0]
154
-
155
- probabilities = genetic_model.predict_proba(input_data)[0]
156
- confidence = float(np.max(probabilities) * 100)
157
-
158
- ar_map = {
159
- "Thalassemia": "ثلاسيميا",
160
- "Normal": "سليم - طبيعي",
161
- "Sickle Cell Anemia": "فقر الدم المنجلي",
162
- "Cystic Fibrosis": "تليف كيسي",
163
- "High Risk": "معرض لخطورة عالية"
164
- }
165
-
166
- return {
167
- "diagnosis": prediction,
168
- "diagnosis_ar": ar_map.get(prediction, "غير معروف"),
169
- "confidence": f"{confidence:.2f}%",
170
- "status": "success"
171
- }
172
-
173
- # ==============================
174
- # مودل التعرف على ا��طعام
175
- # ==============================
176
-
177
- @app.post("/predict_food")
178
- async def predict_food(file: UploadFile = File(...)):
179
-
180
- if not food_model:
181
- return {"error": "Food model is not available"}
182
-
183
- try:
184
- image_bytes = await file.read()
185
-
186
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
187
-
188
- inputs = food_processor(images=image, return_tensors="pt")
189
-
190
- with torch.no_grad():
191
- outputs = food_model(**inputs)
192
-
193
- logits = outputs.logits
194
- predicted_class_id = logits.argmax(-1).item()
195
-
196
- food_name = food_model.config.id2label[predicted_class_id]
197
-
198
- return {
199
- "food_id": predicted_class_id,
200
- "food_name": food_name,
201
- "status": "success"
202
- }
203
-
204
- except Exception as e:
205
- return {
206
- "error": str(e)
207
- }
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ import joblib
3
+ import pandas as pd
4
+ import numpy as np
5
+ from pydantic import BaseModel
6
+ import os
7
+ from transformers import AutoImageProcessor, AutoModelForImageClassification
8
+ from PIL import Image
9
+ import torch
10
+ import io
11
+
12
+ app = FastAPI(title="Thaqafini API - Multi-Model Server")
13
+
14
+ # ==============================
15
+ # المتغيرات العالمية للموديلات
16
+ # ==============================
17
+ maternal_model = None
18
+ genetic_model = None
19
+ food_model = None
20
+ food_processor = None
21
+
22
+ # اسم الموديل العالمي لـ 101 صنف طعام
23
+ FOOD_MODEL_CHECKPOINT = "chriamue/vit-finetuned-food101"
24
+
25
+ # ==============================
26
+ # تحميل الموديلات عند تشغيل السيرفر
27
+ # =============================
28
+ @app.on_event("startup")
29
+ async def load_models():
30
+ global maternal_model, genetic_model, food_model, food_processor
31
+
32
+ # 1. تحميل موديل صحة الأم
33
+ try:
34
+ if os.path.exists("random_forest_model.joblib"):
35
+ maternal_model = joblib.load("random_forest_model.joblib")
36
+ print("✅ Maternal model loaded successfully")
37
+ else:
38
+ print("⚠️ Maternal model file not found.")
39
+ except Exception as e:
40
+ print(f"❌ Error loading Maternal model: {e}")
41
+
42
+ # 2. تحميل موديل الأمراض الوراثية
43
+ try:
44
+ if os.path.exists("thaqafni_model.pkl"):
45
+ genetic_model = joblib.load("thaqafni_model.pkl")
46
+ print("✅ Genetic model loaded successfully")
47
+ else:
48
+ print("⚠️ Genetic model file not found.")
49
+ except Exception as e:
50
+ print(f"❌ Error loading Genetic model: {e}")
51
+
52
+ # 3. تحميل موديل الطعام الشامل (101 صنف) من Hugging Face
53
+ try:
54
+ print(f"🔄 Loading Food-101 model ({FOOD_MODEL_CHECKPOINT})...")
55
+ food_processor = AutoImageProcessor.from_pretrained(FOOD_MODEL_CHECKPOINT)
56
+ food_model = AutoModelForImageClassification.from_pretrained(FOOD_MODEL_CHECKPOINT)
57
+ print("✅ Food-101 model (101 Classes) loaded successfully")
58
+ except Exception as e:
59
+ print(f"❌ Error loading Food model: {e}")
60
+
61
+ # ==============================
62
+ # نماذج البيانات (Pydantic)
63
+ # ==============================
64
+ class MaternalInput(BaseModel):
65
+ age: int
66
+ systolic_bp: int
67
+ diastolic_bp: int
68
+ bs: float
69
+ body_temp: float
70
+ heart_rate: int
71
+
72
+ class GeneticInput(BaseModel):
73
+ age: int
74
+ family_history: int
75
+ hemoglobin: float
76
+ fetal_hemoglobin: float
77
+ sweat_chloride: float
78
+ sickled_rbc_percent: float
79
+
80
+ # ==============================
81
+ # المسارات (Endpoints)
82
+ # ==============================
83
+
84
+ @app.get("/")
85
+ def home():
86
+ return {
87
+ "status": "online",
88
+ "models_status": {
89
+ "maternal": "Ready" if maternal_model else "Not Loaded",
90
+ "genetic": "Ready" if genetic_model else "Not Loaded",
91
+ "food_101": "Ready" if food_model else "Not Loaded"
92
+ }
93
+ }
94
+
95
+ # 1. توقع مخاطر الأم
96
+ @app.post("/predict_maternal")
97
+ async def predict_maternal(data: MaternalInput):
98
+ if not maternal_model:
99
+ return {"error": "Maternal model is not available"}
100
+
101
+ features = np.array([[
102
+ data.age, data.systolic_bp, data.diastolic_bp,
103
+ data.bs, data.body_temp, data.heart_rate
104
+ ]])
105
+ prediction = maternal_model.predict(features)
106
+ return {"risk_level": int(prediction[0])}
107
+
108
+ # 2. توقع الأمراض الوراثية
109
+ @app.post("/predict_genetic")
110
+ async def predict_genetic(data: GeneticInput):
111
+ if not genetic_model:
112
+ return {"error": "Genetic model is not available"}
113
+
114
+ input_data = pd.DataFrame([[
115
+ data.age, data.family_history, data.hemoglobin,
116
+ data.fetal_hemoglobin, data.sweat_chloride, data.sickled_rbc_percent
117
+ ]], columns=['Age', 'Family_History', 'Hemoglobin', 'Fetal_Hemoglobin', 'Sweat_Chloride', 'Sickled_RBC_Percent'])
118
+
119
+ prediction = genetic_model.predict(input_data)[0]
120
+ probabilities = genetic_model.predict_proba(input_data)[0]
121
+ confidence = float(np.max(probabilities) * 100)
122
+
123
+ ar_map = {
124
+ "Thalassemia": "ثلاسيميا",
125
+ "Normal": "سليم - طبيعي",
126
+ "Sickle Cell Anemia": "فقر الدم المنجلي",
127
+ "Cystic Fibrosis": ليف كيسي",
128
+ "High Risk": "معرض لخطورة عالية"
129
+ }
130
+
131
+ return {
132
+ "diagnosis": prediction,
133
+ "diagnosis_ar": ar_map.get(prediction, "غير معروف"),
134
+ "confidence": f"{confidence:.2f}%"
135
+ }
136
+
137
+ # 3. التعرف على 101 صنف طعام
138
+ @app.post("/predict_food")
139
+ async def predict_food(file: UploadFile = File(...)):
140
+ if not food_model or not food_processor:
141
+ return {"error": "Food model is not available"}
142
+
143
+ try:
144
+ # قراءة ومعالجة الصورة
145
+ image_bytes = await file.read()
146
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
147
+
148
+ inputs = food_processor(images=image, return_tensors="pt")
149
+
150
+ with torch.no_grad():
151
+ outputs = food_model(**inputs)
152
+
153
+ # استخراج الاحتمالات وأفضل 3 نتائج
154
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
155
+ top_probs, top_indices = torch.topk(probs, 3)
156
+
157
+ predictions = []
158
+ for i in range(3):
159
+ predictions.append({
160
+ "label": food_model.config.id2label[top_indices[0][i].item()],
161
+ "confidence": f"{top_probs[0][i].item() * 100:.2f}%"
162
+ })
163
+
164
+ return {
165
+ "main_prediction": predictions[0]["label"],
166
+ "all_predictions": predictions,
167
+ "status": "success"
168
+ }
169
+ except Exception as e:
170
+ return {"error": str(e)}