Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import zipfile | |
| import os | |
| import uuid | |
| import shutil | |
| import subprocess | |
| import sys | |
| import time | |
| import tensorflow as tf | |
| from tensorflow.keras.preprocessing.image import ImageDataGenerator | |
| import numpy as np | |
# Storage locations: uploaded dataset archives and exported model artifacts.
UPLOAD_DIR = "uploads"
MODEL_DIR = "models"

# Create both locations up front so later code can write into them directly.
for _storage_dir in (UPLOAD_DIR, MODEL_DIR):
    os.makedirs(_storage_dir, exist_ok=True)
def _safe_model_name(model_name):
    """Return *model_name* reduced to a filesystem-safe base name."""
    # The name comes from a free-text box and is used in a file path, so
    # anything that is not alphanumeric, '-' or '_' is replaced.
    cleaned = "".join(
        ch if ch.isalnum() or ch in "-_" else "_"
        for ch in str(model_name or "").strip()
    )
    return cleaned.strip("_") or "model"


def _convert_to_tfjs(savedmodel_path, tfjs_path):
    """Convert a SavedModel directory to TensorFlow.js format.

    If the ``tensorflowjs_converter`` executable is missing, ``tensorflowjs``
    is installed with pip and the conversion is retried once.

    Raises:
        subprocess.CalledProcessError: if the install or conversion fails.
    """
    command = [
        "tensorflowjs_converter",
        "--input_format=tf_saved_model",
        savedmodel_path,
        tfjs_path,
    ]
    try:
        subprocess.run(command, check=True)
    except FileNotFoundError:
        # Only a missing executable justifies an install; a conversion that
        # ran and failed (CalledProcessError) is propagated to the caller.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "tensorflowjs"],
            check=True,
        )
        subprocess.run(command, check=True)


def _directory_size(root):
    """Return the total size in bytes of all files below *root*."""
    return sum(
        os.path.getsize(os.path.join(dirpath, filename))
        for dirpath, _, filenames in os.walk(root)
        for filename in filenames
    )


def train_and_export(dataset_file, model_name, num_classes, epochs, batch_size, image_size):
    """Train a small CNN image classifier on an uploaded dataset and export it.

    The uploaded ZIP must contain ``train/`` and ``validation/`` folders laid
    out for ``flow_from_directory``. The trained model is written as an H5
    file, a SavedModel and a TensorFlow.js model; the two directory-based
    formats are zipped so each download is a single file.

    Args:
        dataset_file: Uploaded ZIP, either a path string or an object whose
            ``name`` attribute is the path. ``None`` yields an error message.
        model_name: Base name for the H5 file; sanitised before use.
        num_classes: Requested class count. It is replaced by the number of
            classes actually found in ``train/``.
        epochs: Number of training epochs (coerced to ``int``).
        batch_size: Batch size for both generators (coerced to ``int``).
        image_size: Side length in pixels of the square model input
            (coerced to ``int``).

    Returns:
        A 4-tuple ``(message, h5_path, savedmodel_zip, tfjs_zip)``. On any
        failure the message describes the problem and the three paths are
        ``None``; exceptions are not propagated.
    """
    if dataset_file is None:
        return "Error: Please upload a dataset ZIP file", None, None, None

    # Generate unique ID for this training session
    uid = str(uuid.uuid4())
    zip_path = os.path.join(UPLOAD_DIR, f"{uid}.zip")
    extract_path = os.path.join(UPLOAD_DIR, uid)

    try:
        # UI widgets may deliver numbers as floats; Keras needs integers.
        epochs = int(epochs)
        batch_size = int(batch_size)
        image_size = int(image_size)
        safe_name = _safe_model_name(model_name)

        # Accept both a plain path and a file-like wrapper exposing `.name`.
        source_path = getattr(dataset_file, "name", dataset_file)
        shutil.copyfile(source_path, zip_path)

        # Extract dataset
        os.makedirs(extract_path, exist_ok=True)
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_path)

        # Locate and verify train and validation directories
        train_dir = os.path.join(extract_path, "train")
        val_dir = os.path.join(extract_path, "validation")
        if not os.path.isdir(train_dir) or not os.path.isdir(val_dir):
            return "Error: Dataset must contain 'train' and 'validation' folders", None, None, None

        # Augment only the training data; validation is just rescaled.
        train_datagen = ImageDataGenerator(
            rescale=1./255,
            rotation_range=20,
            width_shift_range=0.2,
            height_shift_range=0.2,
            horizontal_flip=True,
            zoom_range=0.2,
        )
        val_datagen = ImageDataGenerator(rescale=1./255)

        train_generator = train_datagen.flow_from_directory(
            train_dir,
            target_size=(image_size, image_size),
            batch_size=batch_size,
            class_mode='categorical',
        )
        val_generator = val_datagen.flow_from_directory(
            val_dir,
            target_size=(image_size, image_size),
            batch_size=batch_size,
            class_mode='categorical',
        )

        if train_generator.samples == 0 or val_generator.samples == 0:
            return "Error: 'train' and 'validation' folders must each contain images in class subfolders", None, None, None
        if val_generator.num_classes != train_generator.num_classes:
            return "Error: 'train' and 'validation' folders must contain the same classes", None, None, None

        # The data, not the slider, decides the size of the output layer.
        num_classes = train_generator.num_classes

        # Build model
        model = tf.keras.Sequential([
            tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(image_size, image_size, 3)),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Dropout(0.25),
            tf.keras.layers.Conv2D(64, 3, activation='relu'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Dropout(0.25),
            tf.keras.layers.Conv2D(128, 3, activation='relu'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.MaxPooling2D(),
            tf.keras.layers.Dropout(0.25),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(256, activation='relu'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.Dropout(0.5),
            tf.keras.layers.Dense(num_classes, activation='softmax'),
        ])
        model.compile(
            optimizer='adam',
            loss='categorical_crossentropy',
            metrics=['accuracy'],
        )

        # Train model. Floor division yields 0 steps when a split holds fewer
        # images than one batch, so clamp to at least one step.
        start_time = time.time()
        history = model.fit(
            train_generator,
            steps_per_epoch=max(1, train_generator.samples // batch_size),
            epochs=epochs,
            validation_data=val_generator,
            validation_steps=max(1, val_generator.samples // batch_size),
            verbose=0,
        )
        training_time = time.time() - start_time

        # Save models
        model_dir = os.path.join(MODEL_DIR, uid)
        os.makedirs(model_dir, exist_ok=True)

        # Save H5 model
        h5_path = os.path.join(model_dir, f"{safe_name}.h5")
        model.save(h5_path)

        # Save SavedModel. Keras 3 rejects `save()` to an extension-less
        # path, so prefer `export()` where the model provides it.
        savedmodel_path = os.path.join(model_dir, "savedmodel")
        if hasattr(model, "export"):
            model.export(savedmodel_path)
        else:
            model.save(savedmodel_path)

        # Convert to TensorFlow.js
        tfjs_path = os.path.join(model_dir, "tfjs")
        _convert_to_tfjs(savedmodel_path, tfjs_path)

        # Measure before zipping so the archives are not counted as well.
        model_size_mb = _directory_size(model_dir) / (1024 * 1024)

        # A file download needs a single file, so zip the directory-based
        # formats. Each archive is written next to (not inside) its source.
        savedmodel_zip = shutil.make_archive(savedmodel_path, "zip", root_dir=savedmodel_path)
        tfjs_zip = shutil.make_archive(tfjs_path, "zip", root_dir=tfjs_path)

        # Prepare results
        result_text = "\n".join([
            "",
            "✅ Training completed successfully!",
            f"⏱️ Training time: {training_time:.2f} seconds",
            f"📊 Best validation accuracy: {max(history.history['val_accuracy']):.4f}",
            f"📦 Model size: {model_size_mb:.2f} MB",
            f"🗂️ Number of classes: {num_classes}",
            "Download links available below ⬇️",
            "",
        ])

        # Return paths for download
        return result_text, h5_path, savedmodel_zip, tfjs_zip
    except Exception as e:
        # Top-level boundary for the UI callback: report instead of crashing.
        return f"❌ Training failed: {str(e)}", None, None, None
    finally:
        # Best-effort cleanup so uploads do not accumulate on disk.
        shutil.rmtree(extract_path, ignore_errors=True)
        try:
            os.remove(zip_path)
        except OSError:
            pass  # the copy may never have been created
# Gradio interface: inputs on the left, training results and downloads on the right.
with gr.Blocks(title="AI Image Classifier Trainer") as demo:
    gr.Markdown("# 🖼️ AI Image Classifier Trainer")
    gr.Markdown("""
    Upload your dataset (ZIP file containing `train/` and `validation/` folders),
    configure training parameters, and download models in multiple formats.
    """)

    with gr.Row():
        # Left column: dataset upload and training configuration.
        with gr.Column():
            dataset = gr.File(label="Dataset ZIP File", file_types=[".zip"])
            model_name = gr.Textbox(label="Model Name", value="my_classifier")
            num_classes = gr.Slider(2, 100, value=5, step=1, label="Number of Classes")
            epochs = gr.Slider(5, 200, value=30, step=1, label="Training Epochs")
            batch_size = gr.Radio([16, 32, 64], value=32, label="Batch Size")
            image_size = gr.Radio([128, 224, 256], value=224, label="Image Size (px)")
            train_btn = gr.Button("🚀 Train Model", variant="primary")

        # Right column: status text plus a download area that starts hidden.
        with gr.Column():
            output = gr.Textbox(label="Training Results", interactive=False)
            with gr.Column(visible=False) as download_col:
                h5_download = gr.File(label="H5 Model Download")
                savedmodel_download = gr.File(label="SavedModel Download")
                tfjs_download = gr.File(label="TensorFlow.js Download")

    def toggle_downloads(result, h5_path, saved_path, tfjs_path):
        """Show the download column when an H5 path is present, else hide and clear it.

        ``result`` is accepted only because the status textbox is wired in as
        an input; it is not used.
        """
        if not h5_path:
            cleared = tuple(gr.File(value=None) for _ in range(3))
            return (gr.Column(visible=False), *cleared)

        populated = tuple(
            gr.File(value=artifact)
            for artifact in (h5_path, saved_path, tfjs_path)
        )
        return (gr.Column(visible=True), *populated)

    training_inputs = [dataset, model_name, num_classes, epochs, batch_size, image_size]
    download_files = [h5_download, savedmodel_download, tfjs_download]

    # Train first, then update the visibility of the download area.
    training_event = train_btn.click(
        fn=train_and_export,
        inputs=training_inputs,
        outputs=[output, *download_files],
    )
    training_event.then(
        fn=toggle_downloads,
        inputs=[output, *download_files],
        outputs=[download_col, *download_files],
    )
# Launch settings for Hugging Face Spaces: listen on all interfaces, port 7860.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "max_file_size": "100mb",  # Allows 100MB file uploads
    }
    demo.launch(**launch_options)