maftuh-main commited on
Commit
a5f6738
·
1 Parent(s): 06761bb

Add safe_mode=False to fix batch_normalization error

Browse files
Files changed (1) hide show
  1. app.py +45 -128
app.py CHANGED
@@ -1,7 +1,6 @@
1
  """
2
  Batik Classifier API - MobileNetV2 Model
3
  95.43% accuracy on 42 batik classes
4
- Efficient mobile/web deployment
5
  """
6
 
7
  import os
@@ -12,191 +11,109 @@ from flask import Flask, request, jsonify
12
  from flask_cors import CORS
13
  from PIL import Image
14
 
15
- # Import TensorFlow
16
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
17
- os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
18
-
19
  import tensorflow as tf
20
- from tensorflow import keras
21
- from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
22
 
23
  app = Flask(__name__)
24
  CORS(app)
25
 
26
- # Global variables
27
  model = None
28
  class_names = None
29
  config = None
30
 
 
 
 
 
31
  def load_models():
32
- """Load MobileNetV2 model and class names"""
33
  global model, class_names, config
34
 
35
- model_dir = "models"
36
-
37
  try:
38
- # Load Keras model with compile=False to avoid compatibility issues
39
- model_path = os.path.join(model_dir, "batik_model.keras")
40
- model = keras.models.load_model(model_path, compile=False)
41
-
42
- # Compile manually
43
- model.compile(
44
- optimizer='adam',
45
- loss='categorical_crossentropy',
46
- metrics=['accuracy']
47
- )
48
-
49
- print(f" Loaded MobileNetV2 model from {model_path}")
50
- print(f" Input shape: {model.input_shape}")
51
- print(f" Output shape: {model.output_shape}")
52
- print(f" Total params: {model.count_params():,}")
53
 
54
- # Load class names
55
- classes_path = os.path.join(model_dir, "batik_classes.json")
56
- with open(classes_path, 'r') as f:
57
  class_names = json.load(f)
58
- print(f" Loaded {len(class_names)} batik classes")
59
 
60
- # Load config
61
- config_path = os.path.join(model_dir, "batik_config.json")
62
- with open(config_path, 'r') as f:
63
- config = json.load(f)
64
- print(f" Model config: {config.get('model', 'Unknown')}")
65
- print(f" Train accuracy: {config.get('train_accuracy', 0):.2%}")
66
- print(f" Val accuracy: {config.get('val_accuracy', 0):.2%}")
67
 
68
  return True
69
-
70
  except Exception as e:
71
- print(f" Error loading models: {e}")
72
- import traceback
73
- traceback.print_exc()
74
  return False
75
 
76
- def preprocess_image(image, target_size=(160, 160)):
77
- """Preprocess image for MobileNetV2 (160x160 input)"""
78
  if image.mode != 'RGB':
79
  image = image.convert('RGB')
80
-
81
- # Resize to 160x160 (MobileNetV2 input size)
82
- image = image.resize(target_size, Image.Resampling.LANCZOS)
83
  img_array = np.array(image, dtype=np.float32)
84
  img_array = np.expand_dims(img_array, axis=0)
85
- img_array = preprocess_input(img_array)
86
-
87
- return img_array
88
 
89
- @app.route('/', methods=['GET'])
90
  def index():
91
- """API info endpoint"""
92
  return jsonify({
93
  "name": "Batik Classifier API",
94
  "model": "MobileNetV2",
95
- "description": "Efficient mobile/web batik classifier",
96
  "classes": len(class_names) if class_names else 0,
97
  "accuracy": config.get('val_accuracy', 0) if config else 0,
98
- "train_accuracy": config.get('train_accuracy', 0) if config else 0,
99
- "epochs": config.get('epochs', 0) if config else 0,
100
  "input_size": "160x160",
101
- "endpoints": {
102
- "/": "API info",
103
- "/predict": "POST - Classify batik image",
104
- "/classes": "GET - List all classes",
105
- "/health": "GET - Health check",
106
- "/info": "GET - Model metadata"
107
- }
108
  })
109
 
110
- @app.route('/health', methods=['GET'])
111
  def health():
112
- """Health check endpoint"""
113
- return jsonify({
114
- "status": "healthy",
115
- "model_loaded": model is not None,
116
- "classes_loaded": class_names is not None,
117
- "model_type": "MobileNetV2"
118
- })
119
 
120
- @app.route('/classes', methods=['GET'])
121
  def get_classes():
122
- """Get all batik classes"""
123
- if class_names is None:
124
- return jsonify({"error": "Classes not loaded"}), 500
125
-
126
- return jsonify({
127
- "classes": class_names,
128
- "total": len(class_names)
129
- })
130
 
131
- @app.route('/info', methods=['GET'])
132
  def get_info():
133
- """Get model metadata"""
134
- if config is None:
135
- return jsonify({"error": "Config not loaded"}), 500
136
-
137
- return jsonify(config)
138
 
139
  @app.route('/predict', methods=['POST'])
140
  def predict():
141
- """Classify batik image using MobileNetV2"""
142
-
143
- if model is None or class_names is None:
144
  return jsonify({"error": "Model not loaded"}), 500
145
 
146
- # Check if image in request
147
  if 'image' not in request.files:
148
- return jsonify({"error": "No image provided"}), 400
149
 
150
  try:
151
- # Read and preprocess image
152
- image_file = request.files['image']
153
- image = Image.open(io.BytesIO(image_file.read()))
154
- processed_image = preprocess_image(image)
155
 
156
- # Predict
157
- predictions = model.predict(processed_image, verbose=0)
158
- predicted_idx = np.argmax(predictions[0])
159
- confidence = float(predictions[0][predicted_idx])
160
- predicted_class = class_names[predicted_idx]
161
 
162
- # Get top 5 predictions
163
- top5_idx = np.argsort(predictions[0])[-5:][::-1]
164
- top5_predictions = [
165
- {
166
- "class": class_names[idx],
167
- "confidence": float(predictions[0][idx])
168
- }
169
- for idx in top5_idx
170
- ]
171
 
172
  return jsonify({
173
- "predicted_class": predicted_class,
174
- "confidence": confidence,
175
- "top5_predictions": top5_predictions,
176
- "model": "MobileNetV2"
177
  })
178
-
179
  except Exception as e:
180
- import traceback
181
- return jsonify({
182
- "error": str(e),
183
- "traceback": traceback.format_exc()
184
- }), 500
185
-
186
- # Load models on startup
187
- print("=" * 70)
188
- print(" Batik Classifier API - MobileNetV2")
189
- print("=" * 70)
190
 
 
191
  if load_models():
192
- print("=" * 70)
193
- print(" All models loaded successfully!")
194
- print(" Ready to classify batik patterns")
195
- print("=" * 70)
196
  else:
197
- print("=" * 70)
198
- print(" Failed to load models")
199
- print("=" * 70)
200
 
201
  if __name__ == '__main__':
202
  app.run(host='0.0.0.0', port=7860)
 
1
  """
2
  Batik Classifier API - MobileNetV2 Model
3
  95.43% accuracy on 42 batik classes
 
4
  """
5
 
6
  import os
 
11
  from flask_cors import CORS
12
  from PIL import Image
13
 
 
14
  os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 
 
15
  import tensorflow as tf
 
 
16
 
17
  app = Flask(__name__)
18
  CORS(app)
19
 
 
20
  model = None
21
  class_names = None
22
  config = None
23
 
24
+ def preprocess_mobilenet(x):
25
+ x = x / 127.5 - 1.0
26
+ return x
27
+
28
  def load_models():
 
29
  global model, class_names, config
30
 
 
 
31
  try:
32
+ # Load with safe_mode=False
33
+ model_path = "models/batik_model.keras"
34
+ model = tf.keras.models.load_model(model_path, compile=False, safe_mode=False)
35
+ print(f"Model loaded: {model.input_shape} -> {model.output_shape}")
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ with open("models/batik_classes.json") as f:
 
 
38
  class_names = json.load(f)
39
+ print(f"Loaded {len(class_names)} classes")
40
 
41
+ try:
42
+ with open("models/batik_config.json") as f:
43
+ config = json.load(f)
44
+ except:
45
+ config = {"model": "MobileNetV2", "val_accuracy": 0.9543, "epochs": 50}
 
 
46
 
47
  return True
 
48
  except Exception as e:
49
+ print(f"Error: {e}")
 
 
50
  return False
51
 
52
+ def preprocess_image(image):
 
53
  if image.mode != 'RGB':
54
  image = image.convert('RGB')
55
+ image = image.resize((160, 160), Image.Resampling.LANCZOS)
 
 
56
  img_array = np.array(image, dtype=np.float32)
57
  img_array = np.expand_dims(img_array, axis=0)
58
+ return preprocess_mobilenet(img_array)
 
 
59
 
60
+ @app.route('/')
61
  def index():
 
62
  return jsonify({
63
  "name": "Batik Classifier API",
64
  "model": "MobileNetV2",
 
65
  "classes": len(class_names) if class_names else 0,
66
  "accuracy": config.get('val_accuracy', 0) if config else 0,
 
 
67
  "input_size": "160x160",
68
+ "status": "ready" if model else "error"
 
 
 
 
 
 
69
  })
70
 
71
+ @app.route('/health')
72
  def health():
73
+ return jsonify({"status": "healthy" if model else "unhealthy"})
 
 
 
 
 
 
74
 
75
+ @app.route('/classes')
76
  def get_classes():
77
+ if not class_names:
78
+ return jsonify({"error": "Not loaded"}), 500
79
+ return jsonify({"classes": class_names, "total": len(class_names)})
 
 
 
 
 
80
 
81
+ @app.route('/info')
82
  def get_info():
83
+ return jsonify(config if config else {})
 
 
 
 
84
 
85
  @app.route('/predict', methods=['POST'])
86
  def predict():
87
+ if not model or not class_names:
 
 
88
  return jsonify({"error": "Model not loaded"}), 500
89
 
 
90
  if 'image' not in request.files:
91
+ return jsonify({"error": "No image"}), 400
92
 
93
  try:
94
+ image = Image.open(io.BytesIO(request.files['image'].read()))
95
+ processed = preprocess_image(image)
 
 
96
 
97
+ preds = model.predict(processed, verbose=0)
98
+ idx = np.argmax(preds[0])
99
+ conf = float(preds[0][idx])
 
 
100
 
101
+ top5 = np.argsort(preds[0])[-5:][::-1]
102
+ top5_preds = [{"class": class_names[i], "confidence": float(preds[0][i])} for i in top5]
 
 
 
 
 
 
 
103
 
104
  return jsonify({
105
+ "predicted_class": class_names[idx],
106
+ "confidence": conf,
107
+ "top5_predictions": top5_preds
 
108
  })
 
109
  except Exception as e:
110
+ return jsonify({"error": str(e)}), 500
 
 
 
 
 
 
 
 
 
111
 
112
+ print("Loading MobileNetV2 model...")
113
  if load_models():
114
+ print("Ready!")
 
 
 
115
  else:
116
+ print("Failed to load")
 
 
117
 
118
  if __name__ == '__main__':
119
  app.run(host='0.0.0.0', port=7860)