| """
|
| Complete Workflow Script
|
| Creates dataset, trains model, and uploads to Hugging Face
|
| """
|
|
|
| import subprocess
|
| import sys
|
| from pathlib import Path
|
|
|
|
|
def print_banner(text):
    """Print *text* centred on a 60-column line, framed by rows of '='.

    A blank line is written before the top rule and after the bottom rule.
    """
    width = 60
    rule = "=" * width
    print(f"\n{rule}")
    print(text.center(width))
    print(rule + "\n")
|
|
|
|
|
def check_dataset(dataset_dir="./dataset"):
    """Report whether a dataset directory containing at least one image exists.

    Only ``*.jpg`` and ``*.png`` files directly inside the directory are
    counted (the search is not recursive).

    Args:
        dataset_dir: Directory to inspect (str or Path). Defaults to
            ``./dataset``, the location the rest of this workflow uses.

    Returns:
        True if the directory exists and holds at least one matching image,
        False otherwise. A one-line status message is printed in every case.
    """
    dataset_dir = Path(dataset_dir)

    # is_dir() (not exists()) so that a regular file at this path is reported
    # as a missing directory rather than as an empty one.
    if not dataset_dir.is_dir():
        print("❌ Dataset directory not found!")
        return False

    images = [*dataset_dir.glob("*.jpg"), *dataset_dir.glob("*.png")]
    if not images:
        print("❌ No images found in dataset directory!")
        return False

    print(f"✓ Dataset found with {len(images)} images")
    return True
|
|
|
|
|
def create_dataset():
    """Create the test dataset in ./dataset, asking before replacing one.

    If ``check_dataset()`` reports an existing dataset, the user is asked
    whether to overwrite it; any answer other than "y" keeps the existing
    data. Whether ``create_test_dataset`` replaces or adds to existing files
    is not visible from this script (NOTE(review): confirm).

    Returns:
        True if the existing dataset was kept or ``create_test_dataset``
        completed without raising; False if it (or its import) raised.
    """
    print_banner("STEP 1: CREATE DATASET")

    if check_dataset():
        print("\nDataset already exists!")
        response = input("Overwrite? (y/n): ").strip().lower()
        if response != 'y':
            print("Skipping dataset creation...")
            return True

    print("\nCreating test dataset...")

    try:
        # Imported inside the try so a missing/broken helper module is
        # reported as a dataset-creation failure instead of crashing.
        from create_test_dataset import create_test_dataset

        create_test_dataset(output_dir="./dataset", num_images=10)
    except Exception as e:
        print(f"\n❌ Error creating dataset: {e}")
        import traceback
        traceback.print_exc()
        return False

    print("\n✅ Dataset created successfully!")
    return True
|
|
|
|
|
def train_model():
    """Train the model by running ``train.py`` in a child Python process.

    The child is started with the current interpreter (``sys.executable``)
    and is given relative paths: ``train.py``, ``./dataset`` as training data
    and ``./models/bytedream`` as output, on device ``cpu``. No ``cwd`` is
    passed, so these resolve against the current working directory.

    Returns:
        True if the training process exited with status 0; False if it
        exited non-zero or could not be started.
    """
    print_banner("STEP 2: TRAIN MODEL")

    print("\nStarting training...")
    print("This may take several minutes depending on your CPU...\n")

    command = [
        sys.executable, "train.py",
        "--train_data", "./dataset",
        "--output_dir", "./models/bytedream",
        "--device", "cpu",
    ]

    try:
        # check=True converts a non-zero exit status into CalledProcessError.
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        print(f"\n❌ Training failed: {e}")
        return False
    except Exception as e:
        # e.g. OSError if the child process cannot be launched at all.
        print(f"\n❌ Error: {e}")
        import traceback
        traceback.print_exc()
        return False

    print("\n✅ Training completed successfully!")
    return True
|
|
|
|
|
def upload_to_hf(token, repo_id):
    """Load the trained model from ./models/bytedream and push it to the Hub.

    Args:
        token: Hugging Face access token, forwarded to ``push_to_hub``.
        repo_id: Target repository id, e.g. ``"user/ByteDream"``.

    Returns:
        True if loading and ``push_to_hub`` completed without raising;
        False if the model directory is missing or either step raised.
    """
    print_banner("STEP 3: UPLOAD TO HUGGING FACE")

    # Keep the exact "./" string for the loader: str(Path(...)) would drop
    # the "./" prefix and yield "models/bytedream", which looks like a Hub
    # repo id.
    model_path = "./models/bytedream"
    if not Path(model_path).exists():
        print("❌ Model directory not found!")
        print("Please train the model first.")
        return False

    print(f"\nUploading to {repo_id}...")

    try:
        # Imported inside the try so a missing/broken package is reported as
        # an upload failure rather than an unhandled crash.
        from bytedream.generator import ByteDreamGenerator

        print("\nLoading model...")
        generator = ByteDreamGenerator(
            model_path=model_path,
            config_path="config.yaml",
            device="cpu",
        )

        # private=False: the upload targets a public repository.
        generator.push_to_hub(
            repo_id=repo_id,
            token=token,
            private=False,
            commit_message="Upload Byte Dream model",
        )
    except Exception as e:
        print(f"\n❌ Upload failed: {e}")
        import traceback
        traceback.print_exc()
        return False

    print("\n✅ Upload successful!")
    print("\n📦 Your model is available at:")
    print(f"https://huggingface.co/{repo_id}")
    return True
|
|
|
|
|
def main():
    """Run the interactive workflow: create dataset, train, then upload.

    Each step must succeed before the next one starts, and the user is asked
    to confirm before training and before uploading. The Hugging Face token
    is read without echo.

    Returns:
        Process exit status: 0 if the workflow finished or the user chose to
        stop early, 1 if a step failed or required input was missing.
    """
    import getpass

    print_banner("BYTE DREAM - COMPLETE WORKFLOW")

    print("This script will:")
    print("1. Create test dataset")
    print("2. Train the model")
    print("3. Upload to Hugging Face")
    print()

    # Step 1: dataset.
    if not create_dataset():
        print("\n❌ Failed to create dataset. Exiting...")
        return 1

    # Step 2: training, only after explicit confirmation.
    print("\nReady to train the model.")
    response = input("Continue to training? (y/n): ").strip().lower()
    if response != 'y':
        print("Training cancelled.")
        return 0

    if not train_model():
        print("\n❌ Training failed. Exiting...")
        return 1

    # Step 3: upload, only after explicit confirmation.
    print("\nReady to upload to Hugging Face.")
    response = input("Continue to upload? (y/n): ").strip().lower()
    if response != 'y':
        print("Upload cancelled.")
        print("\nModel saved to: ./models/bytedream")
        print("To upload later: python publish_to_hf.py")
        return 0

    # getpass does not echo the secret token to the terminal.
    token = getpass.getpass("\nEnter your Hugging Face token (hf_...): ").strip()
    if not token:
        print("❌ Token required!")
        return 1

    repo_id = input("Enter repository ID (e.g., Enzo8930302/ByteDream): ").strip()
    if not repo_id:
        print("❌ Repository ID required!")
        return 1

    if not upload_to_hf(token, repo_id):
        print("\n❌ Upload failed.")
        return 1

    print_banner("WORKFLOW COMPLETED SUCCESSFULLY!")
    print("✅ Dataset created")
    print("✅ Model trained")
    print(f"✅ Uploaded to Hugging Face: {repo_id}")
    print()
    return 0
|
|
|
|
|
if __name__ == "__main__":
    try:
        # sys.exit(None) exits with status 0, so a main() that returns
        # nothing still exits cleanly; an int return becomes the exit status.
        sys.exit(main())
    except KeyboardInterrupt:
        print("\n\nWorkflow interrupted!")
        # 130 = 128 + SIGINT: the conventional status for a Ctrl+C abort,
        # so calling shells/CI do not mistake an interruption for success.
        sys.exit(130)
    except EOFError:
        # input()/getpass raise EOFError when stdin is closed (e.g. piped).
        print("\n\nNo input available; workflow aborted.")
        sys.exit(1)
|
|
|