ameenmarashi commited on
Commit
414eaf1
·
verified ·
1 Parent(s): d297a44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -20
app.py CHANGED
@@ -23,7 +23,7 @@ def train_and_export(dataset_file, model_name, num_classes, epochs, batch_size,
23
  zip_path = os.path.join(UPLOAD_DIR, f"{uid}.zip")
24
 
25
  # Copy uploaded file to our storage
26
- shutil.copyfile(dataset_file, zip_path)
27
 
28
  # Extract dataset
29
  extract_path = os.path.join(UPLOAD_DIR, uid)
@@ -37,7 +37,7 @@ def train_and_export(dataset_file, model_name, num_classes, epochs, batch_size,
37
 
38
  # Verify dataset structure
39
  if not os.path.exists(train_dir) or not os.path.exists(val_dir):
40
- return "Error: Dataset must contain 'train' and 'validation' folders"
41
 
42
  # Create data generators
43
  train_datagen = ImageDataGenerator(
@@ -151,20 +151,19 @@ def train_and_export(dataset_file, model_name, num_classes, epochs, batch_size,
151
  model_size += os.path.getsize(fp)
152
  model_size_mb = model_size / (1024 * 1024)
153
 
154
- # Get class names
155
- class_names = list(train_generator.class_indices.keys())
156
-
157
- # Prepare download links
158
- download_info = f"""
159
  ✅ Training completed successfully!
160
  ⏱️ Training time: {training_time:.2f} seconds
161
- 📊 Validation accuracy: {max(history.history['val_accuracy']):.4f}
162
  📦 Model size: {model_size_mb:.2f} MB
163
  🗂️ Number of classes: {num_classes}
 
 
164
  """
165
 
166
  # Return paths for download
167
- return download_info, h5_path, savedmodel_path, tfjs_path
168
 
169
  except Exception as e:
170
  return f"❌ Training failed: {str(e)}", None, None, None
@@ -172,7 +171,10 @@ def train_and_export(dataset_file, model_name, num_classes, epochs, batch_size,
172
  # Gradio interface
173
  with gr.Blocks(title="AI Image Classifier Trainer") as demo:
174
  gr.Markdown("# 🖼️ AI Image Classifier Trainer")
175
- gr.Markdown("Upload your dataset (ZIP with train/validation folders), configure training, and download models in multiple formats.")
 
 
 
176
 
177
  with gr.Row():
178
  with gr.Column():
@@ -186,16 +188,24 @@ with gr.Blocks(title="AI Image Classifier Trainer") as demo:
186
 
187
  with gr.Column():
188
  output = gr.Textbox(label="Training Results", interactive=False)
189
- h5_download = gr.File(label="H5 Model Download", visible=False)
190
- savedmodel_download = gr.File(label="SavedModel Download", visible=False)
191
- tfjs_download = gr.File(label="TensorFlow.js Download", visible=False)
 
192
 
193
- def toggle_downloads(results, h5_path, saved_path, tfjs_path):
194
- downloads_visible = h5_path is not None
 
 
 
 
 
 
195
  return (
196
- gr.File(visible=downloads_visible, value=h5_path),
197
- gr.File(visible=downloads_visible, value=saved_path),
198
- gr.File(visible=downloads_visible, value=tfjs_path)
 
199
  )
200
 
201
  train_btn.click(
@@ -205,7 +215,7 @@ with gr.Blocks(title="AI Image Classifier Trainer") as demo:
205
  ).then(
206
  fn=toggle_downloads,
207
  inputs=[output, h5_download, savedmodel_download, tfjs_download],
208
- outputs=[h5_download, savedmodel_download, tfjs_download]
209
  )
210
 
211
  # Launch settings for Hugging Face Spaces
@@ -214,5 +224,5 @@ if __name__ == "__main__":
214
  server_name="0.0.0.0",
215
  server_port=7860,
216
  share=False,
217
- max_file_size=100 # 100MB file size limit
218
  )
 
23
  zip_path = os.path.join(UPLOAD_DIR, f"{uid}.zip")
24
 
25
  # Copy uploaded file to our storage
26
+ shutil.copyfile(dataset_file.name, zip_path)
27
 
28
  # Extract dataset
29
  extract_path = os.path.join(UPLOAD_DIR, uid)
 
37
 
38
  # Verify dataset structure
39
  if not os.path.exists(train_dir) or not os.path.exists(val_dir):
40
+ return "Error: Dataset must contain 'train' and 'validation' folders", None, None, None
41
 
42
  # Create data generators
43
  train_datagen = ImageDataGenerator(
 
151
  model_size += os.path.getsize(fp)
152
  model_size_mb = model_size / (1024 * 1024)
153
 
154
+ # Prepare results
155
+ result_text = f"""
 
 
 
156
  ✅ Training completed successfully!
157
  ⏱️ Training time: {training_time:.2f} seconds
158
+ 📊 Best validation accuracy: {max(history.history['val_accuracy']):.4f}
159
  📦 Model size: {model_size_mb:.2f} MB
160
  🗂️ Number of classes: {num_classes}
161
+
162
+ Download links available below ⬇️
163
  """
164
 
165
  # Return paths for download
166
+ return result_text, h5_path, savedmodel_path, tfjs_path
167
 
168
  except Exception as e:
169
  return f"❌ Training failed: {str(e)}", None, None, None
 
171
  # Gradio interface
172
  with gr.Blocks(title="AI Image Classifier Trainer") as demo:
173
  gr.Markdown("# 🖼️ AI Image Classifier Trainer")
174
+ gr.Markdown("""
175
+ Upload your dataset (ZIP file containing `train/` and `validation/` folders),
176
+ configure training parameters, and download models in multiple formats.
177
+ """)
178
 
179
  with gr.Row():
180
  with gr.Column():
 
188
 
189
  with gr.Column():
190
  output = gr.Textbox(label="Training Results", interactive=False)
191
+ with gr.Column(visible=False) as download_col:
192
+ h5_download = gr.File(label="H5 Model Download")
193
+ savedmodel_download = gr.File(label="SavedModel Download")
194
+ tfjs_download = gr.File(label="TensorFlow.js Download")
195
 
196
+ def toggle_downloads(result, h5_path, saved_path, tfjs_path):
197
+ if h5_path:
198
+ return (
199
+ gr.Column(visible=True),
200
+ gr.File(value=h5_path),
201
+ gr.File(value=saved_path),
202
+ gr.File(value=tfjs_path)
203
+ )
204
  return (
205
+ gr.Column(visible=False),
206
+ gr.File(value=None),
207
+ gr.File(value=None),
208
+ gr.File(value=None)
209
  )
210
 
211
  train_btn.click(
 
215
  ).then(
216
  fn=toggle_downloads,
217
  inputs=[output, h5_download, savedmodel_download, tfjs_download],
218
+ outputs=[download_col, h5_download, savedmodel_download, tfjs_download]
219
  )
220
 
221
  # Launch settings for Hugging Face Spaces
 
224
  server_name="0.0.0.0",
225
  server_port=7860,
226
  share=False,
227
+ max_file_size="100mb" # Allows 100MB file uploads
228
  )