hanya70999 commited on
Commit
458ccea
·
verified ·
1 Parent(s): 7baf7b6

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +216 -0
app.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import io
4
+ import base64
5
+ from datetime import datetime
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from torchvision import transforms, models
10
+ import joblib
11
+ from PIL import Image
12
+ from flask import Flask, request, jsonify
13
+ from flask_cors import CORS
14
+ from supabase import create_client, Client
15
+
16
+ app = Flask(__name__)
17
+ CORS(app)
18
+
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+ print(f"Using device: {device}")
21
+
22
+ MODEL_DIR = os.path.join(os.path.dirname(__file__), "models")
23
+ model_path = os.path.join(MODEL_DIR, "svm_densenet201_rbf.joblib")
24
+ meta_path = os.path.join(MODEL_DIR, "metadata.json")
25
+
26
+ svm_model = None
27
+ class_names = None
28
+ IMG_SIZE = 224
29
+
30
+ supabase_url = os.environ.get('SUPABASE_URL')
31
+ supabase_key = os.environ.get('SUPABASE_ANON_KEY')
32
+ supabase: Client = None
33
+
34
+ if supabase_url and supabase_key:
35
+ try:
36
+ supabase = create_client(supabase_url, supabase_key)
37
+ print("✓ Supabase client initialized")
38
+ except Exception as e:
39
+ print(f"⚠ Failed to initialize Supabase: {e}")
40
+ supabase = None
41
+ else:
42
+ print("⚠ Supabase credentials not found, predictions won't be saved to database")
43
+
44
+ def load_model():
45
+ global svm_model, class_names, IMG_SIZE
46
+
47
+ try:
48
+ if os.path.exists(model_path):
49
+ svm_model = joblib.load(model_path)
50
+ print("✓ SVM model loaded successfully")
51
+ else:
52
+ print(f"⚠ Model file not found at {model_path}")
53
+ print(" Using simulation mode until model is uploaded")
54
+ svm_model = None
55
+
56
+ if os.path.exists(meta_path):
57
+ with open(meta_path, "r") as f:
58
+ meta = json.load(f)
59
+ class_names = meta.get("class_names", ["3 Bulan", "6 Bulan", "9 Bulan"])
60
+ IMG_SIZE = meta.get("img_size", 224)
61
+ print(f"✓ Metadata loaded: {class_names}")
62
+ else:
63
+ class_names = ["3 Bulan", "6 Bulan", "9 Bulan"]
64
+ print(f"⚠ Metadata not found, using default classes: {class_names}")
65
+
66
+ except Exception as e:
67
+ print(f"Error loading model: {str(e)}")
68
+ svm_model = None
69
+ class_names = ["3 Bulan", "6 Bulan", "9 Bulan"]
70
+
71
+ densenet = models.densenet201(weights=models.DenseNet201_Weights.DEFAULT)
72
+ densenet.eval()
73
+ feature_extractor = densenet.features.to(device)
74
+ gap = nn.AdaptiveAvgPool2d((1, 1)).to(device)
75
+
76
+ eval_tfms = transforms.Compose([
77
+ transforms.Resize((IMG_SIZE, IMG_SIZE)),
78
+ transforms.ToTensor(),
79
+ transforms.Normalize([0.485, 0.456, 0.406],
80
+ [0.229, 0.224, 0.225]),
81
+ ])
82
+
83
+ def decode_base64_image(base64_string):
84
+ if ',' in base64_string:
85
+ base64_string = base64_string.split(',')[1]
86
+
87
+ image_data = base64.b64decode(base64_string)
88
+ image = Image.open(io.BytesIO(image_data)).convert("RGB")
89
+ return image
90
+
91
+ def preprocess_image(image):
92
+ x = eval_tfms(image).unsqueeze(0)
93
+ return x
94
+
95
+ @torch.no_grad()
96
+ def extract_features(img_tensor):
97
+ img_tensor = img_tensor.to(device)
98
+ feats = feature_extractor(img_tensor)
99
+ feats = torch.relu(feats)
100
+ feats = gap(feats)
101
+ feats = feats.view(feats.size(0), -1)
102
+ return feats.cpu().numpy()
103
+
104
+ def simulate_prediction():
105
+ probabilities = np.random.dirichlet(np.ones(len(class_names)), size=1)[0]
106
+ pred_idx = int(np.argmax(probabilities))
107
+ pred_label = class_names[pred_idx]
108
+ confidence = float(probabilities[pred_idx])
109
+
110
+ return pred_label, confidence, probabilities
111
+
112
+ def predict_with_model(features):
113
+ proba = svm_model.predict_proba(features)[0]
114
+ pred_idx = int(np.argmax(proba))
115
+ pred_label = class_names[pred_idx]
116
+ confidence = float(proba[pred_idx])
117
+
118
+ return pred_label, confidence, proba
119
+
120
+ @app.route('/health', methods=['GET'])
121
+ def health_check():
122
+ return jsonify({
123
+ 'status': 'healthy',
124
+ 'model_loaded': svm_model is not None,
125
+ 'device': str(device),
126
+ 'classes': class_names
127
+ })
128
+
129
+ def save_to_database(pred_label, confidence, prob_dict, mode, image_data_url=None):
130
+ if not supabase:
131
+ return None
132
+
133
+ try:
134
+ prediction_data = {
135
+ 'predicted_class': pred_label,
136
+ 'confidence': confidence,
137
+ 'probabilities': prob_dict,
138
+ 'mode': mode,
139
+ 'created_at': datetime.utcnow().isoformat()
140
+ }
141
+
142
+ if image_data_url:
143
+ prediction_data['image_data'] = image_data_url[:1000]
144
+
145
+ result = supabase.table('predictions').insert(prediction_data).execute()
146
+ return result.data[0] if result.data else None
147
+ except Exception as e:
148
+ print(f"⚠ Failed to save to database: {e}")
149
+ return None
150
+
151
+ @app.route('/classify', methods=['POST'])
152
+ def classify_image():
153
+ try:
154
+ data = request.json
155
+
156
+ if not data or 'image' not in data:
157
+ return jsonify({'error': 'No image data provided'}), 400
158
+
159
+ image_base64 = data['image']
160
+ image = decode_base64_image(image_base64)
161
+
162
+ img_tensor = preprocess_image(image)
163
+
164
+ if svm_model is not None:
165
+ features = extract_features(img_tensor)
166
+ pred_label, confidence, probabilities = predict_with_model(features)
167
+ else:
168
+ pred_label, confidence, probabilities = simulate_prediction()
169
+
170
+ prob_dict = {class_names[i]: float(probabilities[i]) for i in range(len(class_names))}
171
+ mode = 'real' if svm_model is not None else 'simulation'
172
+
173
+ db_record = save_to_database(pred_label, confidence, prob_dict, mode, data['image'])
174
+
175
+ response = {
176
+ 'predicted_class': pred_label,
177
+ 'confidence': confidence,
178
+ 'probabilities': prob_dict,
179
+ 'mode': mode
180
+ }
181
+
182
+ if db_record:
183
+ response['id'] = db_record.get('id')
184
+ response['saved_to_db'] = True
185
+ else:
186
+ response['saved_to_db'] = False
187
+
188
+ return jsonify(response)
189
+
190
+ except Exception as e:
191
+ return jsonify({
192
+ 'error': 'Classification failed',
193
+ 'message': str(e)
194
+ }), 500
195
+
196
+ @app.route('/reload-model', methods=['POST'])
197
+ def reload_model():
198
+ try:
199
+ load_model()
200
+ return jsonify({
201
+ 'status': 'success',
202
+ 'model_loaded': svm_model is not None,
203
+ 'classes': class_names
204
+ })
205
+ except Exception as e:
206
+ return jsonify({
207
+ 'status': 'error',
208
+ 'message': str(e)
209
+ }), 500
210
+
211
+ if __name__ == '__main__':
212
+ os.makedirs(MODEL_DIR, exist_ok=True)
213
+ load_model()
214
+
215
+ port = int(os.environ.get('PORT', 5000))
216
+ app.run(host='0.0.0.0', port=port, debug=False)