ByteDream / workflow_complete.py
Enzo8930302's picture
Upload workflow_complete.py with huggingface_hub
d3516f4 verified
"""
Complete Workflow Script
Creates dataset, trains model, and uploads to Hugging Face
"""
import subprocess
import sys
from pathlib import Path
def print_banner(text, width=60):
    """Print ``text`` centred between two rules of ``=`` characters.

    A blank line is emitted before the top rule and after the bottom rule.

    Args:
        text: Banner caption. A caption longer than ``width`` is printed in
            full (``str.center`` never truncates).
        width: Length of each rule and of the field the caption is centred
            in. Defaults to 60.
    """
    rule = "=" * width
    print("\n" + rule)
    print(text.center(width))
    print(rule + "\n")
def check_dataset(dataset_dir="./dataset"):
    """Report whether a dataset directory with at least one image exists.

    Only ``*.jpg`` and ``*.png`` files directly inside the directory are
    counted; subdirectories are not searched. A status line is printed for
    every outcome.

    Args:
        dataset_dir: Directory to inspect. Defaults to ``./dataset``.

    Returns:
        bool: True if the directory exists and holds at least one matching
        image, False otherwise.
    """
    root = Path(dataset_dir)
    # is_dir() rather than exists(): a regular file with this name is not a
    # usable dataset directory either.
    if not root.is_dir():
        print("❌ Dataset directory not found!")
        return False

    images = [
        path
        for pattern in ("*.jpg", "*.png")
        for path in root.glob(pattern)
    ]
    if not images:
        print("⚠ No images found in dataset directory!")
        return False

    print(f"✓ Dataset found with {len(images)} images")
    return True
def create_dataset():
    """Create the test dataset, asking first if one is already present.

    If ``check_dataset()`` finds an existing dataset, the user is prompted;
    any answer other than ``y`` keeps the existing data and counts as
    success. Otherwise ``create_test_dataset`` is called to write 10 images
    into ``./dataset``.

    Returns:
        bool: True if a dataset was created or an existing one was kept,
        False if creation raised an exception (the traceback is printed).
    """
    print_banner("STEP 1: CREATE DATASET")

    if check_dataset():
        print("\nDataset already exists!")
        response = input("Overwrite? (y/n): ").strip().lower()
        if response != 'y':
            print("Skipping dataset creation...")
            return True

    print("\nCreating test dataset...")
    try:
        # Imported here so that a missing create_test_dataset module is
        # reported through the except branch instead of at script start-up.
        from create_test_dataset import create_test_dataset

        create_test_dataset(output_dir="./dataset", num_images=10)
    except Exception as e:
        print(f"\n❌ Error creating dataset: {e}")
        import traceback
        traceback.print_exc()
        return False
    else:
        print("\n✅ Dataset created successfully!")
        return True
def train_model():
    """Run ``train.py`` in a child process and report the outcome.

    The training script is started with the current interpreter
    (``sys.executable``), reading from ``./dataset``, writing to
    ``./models/bytedream`` and using ``--device cpu``.

    Returns:
        bool: True if the child process exited with status 0, False if it
        exited non-zero or could not be started (the traceback is printed
        in the latter case).
    """
    print_banner("STEP 2: TRAIN MODEL")

    print("\nStarting training...")
    print("This may take several minutes depending on your CPU...\n")

    command = [
        sys.executable, "train.py",
        "--train_data", "./dataset",
        "--output_dir", "./models/bytedream",
        "--device", "cpu",
    ]
    try:
        # check=True turns a non-zero exit status into CalledProcessError.
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        print(f"\n❌ Training failed: {e}")
        return False
    except Exception as e:
        print(f"\n❌ Error: {e}")
        import traceback
        traceback.print_exc()
        return False
    else:
        print("\n✅ Training completed successfully!")
        return True
def upload_to_hf(token, repo_id):
    """Load the trained model and push it to the Hugging Face Hub.

    Args:
        token: Hugging Face access token, forwarded to ``push_to_hub``.
        repo_id: Target repository ID (e.g. ``"user/model"``); also used to
            build the URL printed on success.

    Returns:
        bool: True if loading and ``push_to_hub`` completed without raising,
        False if ``./models/bytedream`` is missing or any exception occurred
        (the traceback is printed).
    """
    print_banner("STEP 3: UPLOAD TO HUGGING FACE")

    model_path = "./models/bytedream"
    if not Path(model_path).is_dir():
        print("❌ Model directory not found!")
        print("Please train the model first.")
        return False

    print(f"\nUploading to {repo_id}...")

    try:
        # Imported here so that a missing bytedream package is reported
        # through the except branch below.
        from bytedream.generator import ByteDreamGenerator

        print("\nLoading model...")
        generator = ByteDreamGenerator(
            model_path=model_path,
            config_path="config.yaml",
            device="cpu",
        )

        # private=False is passed through unchanged; presumably this makes
        # the repository public - confirm against push_to_hub.
        generator.push_to_hub(
            repo_id=repo_id,
            token=token,
            private=False,
            commit_message="Upload Byte Dream model",
        )
    except Exception as e:
        print(f"\n❌ Upload failed: {e}")
        import traceback
        traceback.print_exc()
        return False
    else:
        print("\n✅ Upload successful!")
        print("\n📦 Your model is available at:")
        print(f"https://huggingface.co/{repo_id}")
        return True
def main():
    """Run the interactive dataset -> training -> upload workflow.

    Each step is confirmed on the console before it starts. Declining a step
    stops the workflow without treating it as an error.

    Returns:
        int: Process exit status. 0 when the workflow finishes or the user
        declines a step; 1 when a step fails or a required credential is
        left empty.
    """
    print_banner("BYTE DREAM - COMPLETE WORKFLOW")

    print("This script will:")
    print("1. Create test dataset")
    print("2. Train the model")
    print("3. Upload to Hugging Face")
    print()

    # Step 1: Create dataset
    if not create_dataset():
        print("\n❌ Failed to create dataset. Exiting...")
        return 1

    # Step 2: Train
    print("\nReady to train the model.")
    response = input("Continue to training? (y/n): ").strip().lower()
    if response != 'y':
        print("Training cancelled.")
        return 0

    if not train_model():
        print("\n❌ Training failed. Exiting...")
        return 1

    # Step 3: Upload to HF
    print("\nReady to upload to Hugging Face.")
    response = input("Continue to upload? (y/n): ").strip().lower()
    if response != 'y':
        print("Upload cancelled.")
        print("\nModel saved to: ./models/bytedream")
        print("To upload later: python publish_to_hf.py")
        return 0

    # The token is a secret: do not echo it when typed at a terminal.
    # Piped/redirected stdin keeps using input() so scripted runs still work.
    token_prompt = "\nEnter your Hugging Face token (hf_...): "
    if sys.stdin.isatty():
        from getpass import getpass
        token = getpass(token_prompt).strip()
    else:
        token = input(token_prompt).strip()
    if not token:
        print("❌ Token required!")
        return 1

    repo_id = input("Enter repository ID (e.g., Enzo8930302/ByteDream): ").strip()
    if not repo_id:
        print("❌ Repository ID required!")
        return 1

    if not upload_to_hf(token, repo_id):
        print("\n❌ Upload failed.")
        return 1

    # Success!
    print_banner("WORKFLOW COMPLETED SUCCESSFULLY!")
    print("✅ Dataset created")
    print("✅ Model trained")
    print(f"✅ Uploaded to Hugging Face: {repo_id}")
    print()
    return 0


if __name__ == "__main__":
    try:
        exit_code = main()
    except KeyboardInterrupt:
        print("\n\nWorkflow interrupted!")
        # 128 + SIGINT: the conventional shell status for Ctrl-C, so callers
        # do not mistake an aborted run for a successful one.
        exit_code = 130
    sys.exit(exit_code)