""" Complete Workflow Script Creates dataset, trains model, and uploads to Hugging Face """ import subprocess import sys from pathlib import Path def print_banner(text): print("\n" + "="*60) print(text.center(60)) print("="*60 + "\n") def check_dataset(): """Check if dataset exists""" dataset_dir = Path("./dataset") if not dataset_dir.exists(): print("❌ Dataset directory not found!") return False # Check for images images = list(dataset_dir.glob("*.jpg")) + list(dataset_dir.glob("*.png")) if len(images) == 0: print("⚠ No images found in dataset directory!") return False print(f"✓ Dataset found with {len(images)} images") return True def create_dataset(): """Create test dataset""" print_banner("STEP 1: CREATE DATASET") if check_dataset(): print("\nDataset already exists!") response = input("Overwrite? (y/n): ").strip().lower() if response != 'y': print("Skipping dataset creation...") return True print("\nCreating test dataset...") try: from create_test_dataset import create_test_dataset create_test_dataset(output_dir="./dataset", num_images=10) print("\n✅ Dataset created successfully!") return True except Exception as e: print(f"\n❌ Error creating dataset: {e}") import traceback traceback.print_exc() return False def train_model(): """Train the model""" print_banner("STEP 2: TRAIN MODEL") print("\nStarting training...") print("This may take several minutes depending on your CPU...\n") try: # Run training result = subprocess.run([ sys.executable, "train.py", "--train_data", "./dataset", "--output_dir", "./models/bytedream", "--device", "cpu" ], check=True) print("\n✅ Training completed successfully!") return True except subprocess.CalledProcessError as e: print(f"\n❌ Training failed: {e}") return False except Exception as e: print(f"\n❌ Error: {e}") import traceback traceback.print_exc() return False def upload_to_hf(token, repo_id): """Upload to Hugging Face""" print_banner("STEP 3: UPLOAD TO HUGGING FACE") # Check if model exists model_dir = Path("./models/bytedream") if not model_dir.exists(): print("❌ Model directory not found!") print("Please train the model first.") return False print(f"\nUploading to {repo_id}...") try: from bytedream.generator import ByteDreamGenerator # Load generator print("\nLoading model...") generator = ByteDreamGenerator( model_path="./models/bytedream", config_path="config.yaml", device="cpu", ) # Upload to HF generator.push_to_hub( repo_id=repo_id, token=token, private=False, commit_message="Upload Byte Dream model", ) print("\n✅ Upload successful!") print(f"\n📦 Your model is available at:") print(f"https://huggingface.co/{repo_id}") return True except Exception as e: print(f"\n❌ Upload failed: {e}") import traceback traceback.print_exc() return False def main(): """Main workflow""" print_banner("BYTE DREAM - COMPLETE WORKFLOW") print("This script will:") print("1. Create test dataset") print("2. Train the model") print("3. Upload to Hugging Face") print() # Step 1: Create dataset if not create_dataset(): print("\n❌ Failed to create dataset. Exiting...") return # Step 2: Train print("\nReady to train the model.") response = input("Continue to training? (y/n): ").strip().lower() if response != 'y': print("Training cancelled.") return if not train_model(): print("\n❌ Training failed. Exiting...") return # Step 3: Upload to HF print("\nReady to upload to Hugging Face.") response = input("Continue to upload? (y/n): ").strip().lower() if response != 'y': print("Upload cancelled.") print("\nModel saved to: ./models/bytedream") print("To upload later: python publish_to_hf.py") return # Get HF credentials token = input("\nEnter your Hugging Face token (hf_...): ").strip() if not token: print("❌ Token required!") return repo_id = input("Enter repository ID (e.g., Enzo8930302/ByteDream): ").strip() if not repo_id: print("❌ Repository ID required!") return if not upload_to_hf(token, repo_id): print("\n❌ Upload failed.") return # Success! print_banner("WORKFLOW COMPLETED SUCCESSFULLY!") print("✅ Dataset created") print("✅ Model trained") print(f"✅ Uploaded to Hugging Face: {repo_id}") print() if __name__ == "__main__": try: main() except KeyboardInterrupt: print("\n\nWorkflow interrupted!") sys.exit(0)