hanya70999 commited on
Commit
afb5ef1
·
verified ·
1 Parent(s): d35eb0d

Upload 6 files

Browse files
Dockerfile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ RUN apt-get update && apt-get install -y \
6
+ build-essential \
7
+ && rm -rf /var/lib/apt/lists/*
8
+
9
+ COPY requirements.txt .
10
+
11
+ RUN pip install --no-cache-dir -r requirements.txt
12
+
13
+ COPY . .
14
+
15
+ RUN mkdir -p models
16
+
17
+ EXPOSE 7860
18
+
19
+ CMD ["gunicorn", "--bind", "0.0.0.0:7860", "--workers", "2", "--timeout", "120", "app:app"]
app.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import io
4
+ import base64
5
+ from datetime import datetime
6
+ from threading import Lock
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ from torchvision import transforms, models
12
+ import joblib
13
+ from PIL import Image
14
+ from flask import Flask, request, jsonify
15
+ from flask_cors import CORS
16
+ from supabase import create_client, Client
17
+
18
+ # =========================
19
+ # Flask App
20
+ # =========================
21
+ app = Flask(__name__)
22
+ CORS(app)
23
+
24
+ # =========================
25
+ # Device
26
+ # =========================
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ print(f"Using device: {device}")
29
+
30
+ # =========================
31
+ # Paths
32
+ # =========================
33
+ MODEL_DIR = os.path.join(os.path.dirname(__file__), "models")
34
+ model_path = os.path.join(MODEL_DIR, "svm_densenet201_rbf.joblib")
35
+ meta_path = os.path.join(MODEL_DIR, "metadata.json")
36
+
37
+ # =========================
38
+ # Globals (Models & Config)
39
+ # =========================
40
+ svm_model = None
41
+ class_names = None
42
+ IMG_SIZE = 224
43
+
44
+ # DenseNet globals
45
+ densenet = None
46
+ feature_extractor = None
47
+ gap = None
48
+
49
+ # Transform global (will be built after metadata loaded)
50
+ eval_tfms = None
51
+
52
+ # Load flags + lock (safe for concurrent requests)
53
+ model_loaded = False
54
+ densenet_loaded = False
55
+ load_lock = Lock()
56
+
57
+ # =========================
58
+ # Supabase
59
+ # =========================
60
+ supabase_url = os.environ.get("SUPABASE_URL")
61
+ supabase_key = os.environ.get("SUPABASE_ANON_KEY")
62
+ supabase: Client = None
63
+
64
+ if supabase_url and supabase_key:
65
+ try:
66
+ supabase = create_client(supabase_url, supabase_key)
67
+ print("✓ Supabase client initialized")
68
+ except Exception as e:
69
+ print(f"⚠ Failed to initialize Supabase: {e}")
70
+ supabase = None
71
+ else:
72
+ print("⚠ Supabase credentials not found, predictions won't be saved to database")
73
+
74
+
75
+ # =========================
76
+ # Helpers
77
+ # =========================
78
+ def format_class_name(raw_name: str) -> str:
79
+ """Convert usia_3_bulan to 3 Bulan for display"""
80
+ mapping = {
81
+ "usia_3_bulan": "3 Bulan",
82
+ "usia_6_bulan": "6 Bulan",
83
+ "usia_9_bulan": "9 Bulan"
84
+ }
85
+ return mapping.get(raw_name, raw_name)
86
+
87
+
88
+ def build_eval_transforms(img_size: int):
89
+ """Build transforms using current IMG_SIZE"""
90
+ return transforms.Compose([
91
+ transforms.Resize((img_size, img_size)),
92
+ transforms.ToTensor(),
93
+ transforms.Normalize([0.485, 0.456, 0.406],
94
+ [0.229, 0.224, 0.225]),
95
+ ])
96
+
97
+
98
+ def decode_base64_image(base64_string: str) -> Image.Image:
99
+ if "," in base64_string:
100
+ base64_string = base64_string.split(",")[1]
101
+ image_data = base64.b64decode(base64_string)
102
+ image = Image.open(io.BytesIO(image_data)).convert("RGB")
103
+ return image
104
+
105
+
106
+ def preprocess_image(image: Image.Image) -> torch.Tensor:
107
+ global eval_tfms
108
+ if eval_tfms is None:
109
+ # fallback if metadata not yet loaded
110
+ eval_tfms = build_eval_transforms(IMG_SIZE)
111
+ x = eval_tfms(image).unsqueeze(0)
112
+ return x
113
+
114
+
115
+ # =========================
116
+ # Loading: SVM + Metadata
117
+ # =========================
118
+ def load_model():
119
+ """
120
+ Load SVM + metadata safely (works under gunicorn too).
121
+ Lazy loaded on first request /classify.
122
+ """
123
+ global svm_model, class_names, IMG_SIZE, model_loaded, eval_tfms
124
+
125
+ if model_loaded:
126
+ return
127
+
128
+ with load_lock:
129
+ if model_loaded:
130
+ return
131
+
132
+ os.makedirs(MODEL_DIR, exist_ok=True)
133
+
134
+ try:
135
+ print(f"🔍 Checking model directory: {MODEL_DIR}")
136
+ print(f" Model path: {model_path}")
137
+ print(f" Metadata path: {meta_path}")
138
+ print(f" Model exists: {os.path.exists(model_path)}")
139
+ print(f" Metadata exists: {os.path.exists(meta_path)}")
140
+
141
+ if os.path.exists(MODEL_DIR):
142
+ files = os.listdir(MODEL_DIR)
143
+ print(f" Files in models/: {files}")
144
+
145
+ # ---- Load SVM ----
146
+ if os.path.exists(model_path):
147
+ print("⏳ Loading SVM model...")
148
+ svm_model = joblib.load(model_path)
149
+ print("✓ SVM model loaded successfully")
150
+ else:
151
+ print(f"⚠ Model file not found at {model_path}")
152
+ print(" Using simulation mode until model is uploaded")
153
+ svm_model = None
154
+
155
+ # ---- Load Metadata ----
156
+ if os.path.exists(meta_path):
157
+ with open(meta_path, "r") as f:
158
+ meta = json.load(f)
159
+ class_names = meta.get("class_names", ["usia_3_bulan", "usia_6_bulan", "usia_9_bulan"])
160
+ IMG_SIZE = int(meta.get("img_size", 224))
161
+ print(f"✓ Metadata loaded: class_names={class_names}, IMG_SIZE={IMG_SIZE}")
162
+ else:
163
+ class_names = ["usia_3_bulan", "usia_6_bulan", "usia_9_bulan"]
164
+ IMG_SIZE = 224
165
+ print(f"⚠ Metadata not found, using default classes: {class_names}, IMG_SIZE={IMG_SIZE}")
166
+
167
+ # IMPORTANT: rebuild transforms after IMG_SIZE updated
168
+ eval_tfms = build_eval_transforms(IMG_SIZE)
169
+
170
+ model_loaded = True
171
+
172
+ except Exception as e:
173
+ print(f"❌ Error loading model: {str(e)}")
174
+ import traceback
175
+ traceback.print_exc()
176
+
177
+ svm_model = None
178
+ class_names = ["usia_3_bulan", "usia_6_bulan", "usia_9_bulan"]
179
+ IMG_SIZE = 224
180
+ eval_tfms = build_eval_transforms(IMG_SIZE)
181
+ model_loaded = True
182
+
183
+
184
+ # =========================
185
+ # Loading: DenseNet201
186
+ # =========================
187
+ def load_densenet():
188
+ global densenet, feature_extractor, gap, densenet_loaded
189
+
190
+ if densenet_loaded:
191
+ return
192
+
193
+ with load_lock:
194
+ if densenet_loaded:
195
+ return
196
+
197
+ print("⏳ Loading DenseNet201 (first time may take a while)...")
198
+ densenet = models.densenet201(weights=models.DenseNet201_Weights.DEFAULT)
199
+ densenet.eval()
200
+ feature_extractor = densenet.features.to(device)
201
+ gap = nn.AdaptiveAvgPool2d((1, 1)).to(device)
202
+ densenet_loaded = True
203
+ print("✓ DenseNet201 loaded successfully")
204
+
205
+
206
+ @torch.no_grad()
207
+ def extract_features(img_tensor: torch.Tensor) -> np.ndarray:
208
+ load_densenet()
209
+ img_tensor = img_tensor.to(device)
210
+ feats = feature_extractor(img_tensor)
211
+ feats = torch.relu(feats)
212
+ feats = gap(feats)
213
+ feats = feats.view(feats.size(0), -1)
214
+ return feats.cpu().numpy()
215
+
216
+
217
+ # =========================
218
+ # Prediction
219
+ # =========================
220
+ def simulate_prediction():
221
+ if not class_names:
222
+ _classes = ["usia_3_bulan", "usia_6_bulan", "usia_9_bulan"]
223
+ else:
224
+ _classes = class_names
225
+
226
+ probabilities = np.random.dirichlet(np.ones(len(_classes)), size=1)[0]
227
+ pred_idx = int(np.argmax(probabilities))
228
+ pred_label = _classes[pred_idx]
229
+ confidence = float(probabilities[pred_idx])
230
+ return pred_label, confidence, probabilities
231
+
232
+
233
+ def predict_with_model(features: np.ndarray):
234
+ proba = svm_model.predict_proba(features)[0]
235
+ pred_idx = int(np.argmax(proba))
236
+ pred_label = class_names[pred_idx]
237
+ confidence = float(proba[pred_idx])
238
+ return pred_label, confidence, proba
239
+
240
+
241
+ # =========================
242
+ # Database Save
243
+ # =========================
244
+ def save_to_database(pred_label, confidence, prob_dict, mode, image_data_url=None):
245
+ if not supabase:
246
+ return None
247
+
248
+ try:
249
+ prediction_data = {
250
+ "predicted_class": pred_label,
251
+ "confidence": float(confidence),
252
+ "probabilities": prob_dict,
253
+ "mode": mode,
254
+ "created_at": datetime.utcnow().isoformat(),
255
+ }
256
+
257
+ if image_data_url:
258
+ # truncate for safety
259
+ prediction_data["image_data"] = image_data_url[:1000]
260
+ # Save full image for display
261
+ prediction_data["image_url"] = image_data_url
262
+
263
+ result = supabase.table("predictions").insert(prediction_data).execute()
264
+ return result.data[0] if result.data else None
265
+ except Exception as e:
266
+ print(f"⚠ Failed to save to database: {e}")
267
+ return None
268
+
269
+
270
+ # =========================
271
+ # Routes
272
+ # =========================
273
+ @app.route("/", methods=["GET"])
274
+ def home():
275
+ return jsonify({
276
+ "service": "Seedling Classifier API",
277
+ "status": "running",
278
+ "version": "1.0.0",
279
+ "endpoints": {
280
+ "health": "/health",
281
+ "classify": "/classify (POST)",
282
+ "reload_model": "/reload-model (POST)",
283
+ "warmup": "/warmup (POST)",
284
+ },
285
+ "note": "Open /health to verify. Use POST /classify with JSON {image: base64DataURL}."
286
+ })
287
+
288
+
289
+ @app.route("/health", methods=["GET"])
290
+ def health_check():
291
+ default_classes = ["usia_3_bulan", "usia_6_bulan", "usia_9_bulan"]
292
+ current_classes = class_names if class_names else default_classes
293
+ display_classes = [format_class_name(c) for c in current_classes]
294
+
295
+ return jsonify({
296
+ "status": "healthy",
297
+ "model_loaded": svm_model is not None,
298
+ "densenet_loaded": feature_extractor is not None,
299
+ "device": str(device),
300
+ "classes": display_classes,
301
+ "ready": True
302
+ })
303
+
304
+
305
+ @app.route("/classify", methods=["POST"])
306
+ def classify_image():
307
+ try:
308
+ # Lazy-load model + metadata on first request
309
+ if not model_loaded:
310
+ load_model()
311
+
312
+ data = request.get_json(silent=True)
313
+
314
+ if not data or "image" not in data:
315
+ return jsonify({"error": "No image data provided"}), 400
316
+
317
+ image_base64 = data["image"]
318
+ image = decode_base64_image(image_base64)
319
+ img_tensor = preprocess_image(image)
320
+
321
+ # Use real model if available, else simulation mode
322
+ if svm_model is not None:
323
+ features = extract_features(img_tensor)
324
+ pred_label, confidence, probabilities = predict_with_model(features)
325
+ mode = "real"
326
+ else:
327
+ pred_label, confidence, probabilities = simulate_prediction()
328
+ mode = "simulation"
329
+
330
+ # Ensure class_names exists
331
+ _classes = class_names if class_names else ["usia_3_bulan", "usia_6_bulan", "usia_9_bulan"]
332
+
333
+ prob_dict = {format_class_name(_classes[i]): float(probabilities[i]) for i in range(len(_classes))}
334
+ formatted_pred_label = format_class_name(pred_label)
335
+
336
+ db_record = save_to_database(formatted_pred_label, confidence, prob_dict, mode, data.get("image"))
337
+
338
+ response = {
339
+ "predicted_class": formatted_pred_label,
340
+ "confidence": float(confidence),
341
+ "probabilities": prob_dict,
342
+ "mode": mode,
343
+ "saved_to_db": bool(db_record),
344
+ }
345
+
346
+ if db_record:
347
+ response["id"] = db_record.get("id")
348
+
349
+ return jsonify(response)
350
+
351
+ except Exception as e:
352
+ return jsonify({
353
+ "error": "Classification failed",
354
+ "message": str(e)
355
+ }), 500
356
+
357
+
358
+ @app.route("/reload-model", methods=["POST"])
359
+ def reload_model_route():
360
+ global model_loaded, svm_model, class_names, eval_tfms
361
+
362
+ try:
363
+ with load_lock:
364
+ model_loaded = False
365
+ svm_model = None
366
+ class_names = None
367
+ eval_tfms = None
368
+
369
+ load_model()
370
+
371
+ display_classes = [format_class_name(c) for c in class_names] if class_names else []
372
+ return jsonify({
373
+ "status": "success",
374
+ "model_loaded": svm_model is not None,
375
+ "classes": display_classes
376
+ })
377
+
378
+ except Exception as e:
379
+ return jsonify({
380
+ "status": "error",
381
+ "message": str(e)
382
+ }), 500
383
+
384
+
385
+ @app.route("/warmup", methods=["POST"])
386
+ def warmup():
387
+ try:
388
+ load_densenet()
389
+ return jsonify({
390
+ "status": "success",
391
+ "densenet_loaded": feature_extractor is not None,
392
+ "device": str(device)
393
+ })
394
+ except Exception as e:
395
+ return jsonify({
396
+ "status": "error",
397
+ "message": str(e)
398
+ }), 500
399
+
400
+
401
+ # =========================
402
+ # Local run (optional)
403
+ # =========================
404
+ if __name__ == "__main__":
405
+ os.makedirs(MODEL_DIR, exist_ok=True)
406
+ print("🚀 Starting locally...")
407
+
408
+ # Optional: uncomment to preload on local run
409
+ # load_model()
410
+ # load_densenet()
411
+
412
+ port = int(os.environ.get("PORT", 7860))
413
+ app.run(host="0.0.0.0", port=port, debug=False)
models/gitkeep ADDED
File without changes
models/metadata.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "split_base": "/content/drive/MyDrive/dataset_split",
3
+ "class_names": [
4
+ "usia_3_bulan",
5
+ "usia_6_bulan",
6
+ "usia_9_bulan"
7
+ ],
8
+ "class_to_idx": {
9
+ "usia_3_bulan": 0,
10
+ "usia_6_bulan": 1,
11
+ "usia_9_bulan": 2
12
+ },
13
+ "img_size": 224,
14
+ "feature_dim": 1920,
15
+ "best_params": {
16
+ "C": 10,
17
+ "gamma": 0.001,
18
+ "kernel": "rbf"
19
+ }
20
+ }
models/svm_densenet201_rbf.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78ea6d2911660bb66aa1c53ee9c59ed4e38178bfd7e221e8df8db84336263cf3
3
+ size 5467299
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ flask==3.0.0
2
+ flask-cors==4.0.0
3
+ torch==2.1.0
4
+ torchvision==0.16.0
5
+ numpy==1.24.3
6
+ Pillow==10.1.0
7
+ scikit-learn==1.3.2
8
+ joblib==1.3.2
9
+ gunicorn==21.2.0
10
+ supabase==2.9.0