ByteDream / quick_start.py
Enzo8930302's picture
Upload quick_start.py with huggingface_hub
165a196 verified
"""
Quick Start Script for Hugging Face Integration
Helps you upload/download models easily
"""
import sys
from pathlib import Path
def print_banner(text):
    """Print *text* centred on a 60-column line, framed by rules of '='.

    A blank line is emitted before the top rule and after the bottom rule.
    """
    width = 60
    rule = "=" * width
    print("", rule, sep="\n")
    print(text.center(width))
    print(rule, "", sep="\n")
def check_model_exists(model_dir="./models/bytedream"):
    """Check whether a trained Byte Dream model is available locally.

    The model counts as present when ``model_dir`` exists and contains a
    weights file, either at ``unet/pytorch_model.bin`` or as a single
    ``pytorch_model.bin`` at the directory root. When the check fails, a
    hint on how to obtain a model is printed.

    Args:
        model_dir: Directory to inspect (str or Path). Defaults to
            ``./models/bytedream``, the path the rest of this script loads from.

    Returns:
        True if a weights file was found, False otherwise.
    """
    model_path = Path(model_dir)

    if not model_path.exists():
        print("❌ Model directory not found!")
        print("\nPlease train the model first:")
        print(" python train.py")
        print("\nOr download from Hugging Face:")
        print(" python infer.py --hf_repo username/repo --prompt 'test'")
        return False

    # Either layout is accepted: per-component subfolder or one root checkpoint.
    # (Only these files are checked; other components are not verified here.)
    weight_candidates = (
        model_path / "unet" / "pytorch_model.bin",
        model_path / "pytorch_model.bin",
    )
    if not any(candidate.exists() for candidate in weight_candidates):
        print("⚠ Model directory exists but no weights found!")
        print("Please train the model first.")
        return False

    return True
def upload_to_hf():
    """Interactively upload the local Byte Dream model to the Hugging Face Hub.

    Verifies a local model exists, prompts for an access token (read without
    echoing it to the terminal) and a repository ID, loads the model from
    ``./models/bytedream`` on CPU and calls ``push_to_hub`` with
    ``private=False``. Any error is printed together with its traceback
    instead of being raised, so the interactive session ends cleanly.
    """
    from getpass import getpass

    print_banner("UPLOAD TO HUGGING FACE HUB")

    if not check_model_exists():
        return

    # getpass keeps the secret token off the screen and out of scrollback,
    # unlike input(), which echoes everything typed or pasted.
    token = getpass("Enter your Hugging Face token (hf_...): ").strip()
    if not token:
        print("❌ Token is required!")
        return

    repo_id = input("Enter repository ID (e.g., username/ByteDream): ").strip()
    if not repo_id:
        print("❌ Repository ID is required!")
        return

    print(f"\n📤 Uploading to {repo_id}...")

    try:
        # Imported lazily so the menu works even if the package is missing.
        from bytedream.generator import ByteDreamGenerator

        print("\nLoading model...")
        generator = ByteDreamGenerator(
            model_path="./models/bytedream",
            config_path="config.yaml",
            device="cpu",
        )

        generator.push_to_hub(
            repo_id=repo_id,
            token=token,
            private=False,
            commit_message="Upload Byte Dream model",
        )

        print("\n✅ SUCCESS!")
        print("\n📦 Your model is available at:")
        print(f"https://huggingface.co/{repo_id}")
        print("\nTo use this model:")
        print(f" python infer.py --prompt 'your prompt' --hf_repo '{repo_id}'")
        print("="*60)
    except Exception as e:
        # Top-level boundary of an interactive command: report, don't crash.
        print(f"\n❌ Error: {e}")
        import traceback
        traceback.print_exc()
def download_from_hf():
    """Interactively load a Byte Dream model from a Hugging Face Hub repository.

    Asks for a repository ID and builds a CPU generator from it. If the user
    answers ``y``, it renders one 256x256 test image in 10 steps and saves it
    to ``test_output.png``. It then prints usage hints. Any error is printed
    with its traceback rather than propagated.
    """
    print_banner("DOWNLOAD FROM HUGGING FACE HUB")

    repo_id = input("Enter repository ID (e.g., username/ByteDream): ").strip()
    if not repo_id:
        print("❌ Repository ID is required!")
        return

    print(f"\n📥 Downloading from {repo_id}...")

    try:
        from bytedream.generator import ByteDreamGenerator

        hub_model = ByteDreamGenerator(hf_repo_id=repo_id, config_path="config.yaml", device="cpu")
        print("\n✅ Model loaded successfully!")

        wants_sample = input("\nGenerate test image? (y/n): ").strip().lower() == 'y'
        if wants_sample:
            print("\nGenerating test image...")
            sample_settings = {
                "prompt": "test pattern, simple colors",
                "width": 256,
                "height": 256,
                "num_inference_steps": 10,
            }
            sample = hub_model.generate(**sample_settings)
            out_file = "test_output.png"
            sample.save(out_file)
            print(f"✓ Test image saved to: {out_file}")

        for hint in (
            "\nTo generate images:",
            f" python infer.py --prompt 'your prompt' --hf_repo '{repo_id}'",
            f" HF_REPO_ID={repo_id} python app.py",
            "="*60,
        ):
            print(hint)
    except Exception as err:
        print(f"\n❌ Error: {err}")
        import traceback
        traceback.print_exc()
def test_local_model():
    """Smoke-test the locally trained model.

    After ``check_model_exists`` succeeds, this loads ``./models/bytedream``
    on CPU and renders one 256x256 image in 10 steps. The image is saved to
    ``test_output.png``, followed by upload instructions. Any error is
    printed with its traceback rather than propagated.
    """
    print_banner("TEST LOCAL MODEL")
    if not check_model_exists():
        return

    print("Loading local model...")
    try:
        from bytedream.generator import ByteDreamGenerator

        local_model = ByteDreamGenerator(model_path="./models/bytedream", config_path="config.yaml", device="cpu")
        print("\n✅ Model loaded successfully!")

        print("\nGenerating test image...")
        rendered = local_model.generate(
            prompt="test pattern, simple colors",
            width=256,
            height=256,
            num_inference_steps=10,
        )
        destination = "test_output.png"
        rendered.save(destination)

        for line in (
            f"✓ Test image saved to: {destination}",
            "\nModel ready for upload!",
            "To upload: python quick_start.py upload",
            "="*60,
        ):
            print(line)
    except Exception as exc:
        print(f"\n❌ Error: {exc}")
        import traceback
        traceback.print_exc()
def main(argv=None):
    """Run one quick-start action, chosen on the command line or from a menu.

    With no arguments, an interactive menu is shown (unchanged behavior). An
    action may also be given directly, e.g. ``python quick_start.py upload``.
    Accepted names are ``upload``, ``download``, ``test`` and ``exit``, or the
    menu numbers ``1``-``4``. This makes the ``python quick_start.py upload``
    hint printed after a local test actually work.

    Args:
        argv: Command-line arguments without the program name. Defaults to
            ``sys.argv[1:]``.
    """
    if argv is None:
        argv = sys.argv[1:]

    # Menu digits and command names map to the same handlers.
    actions = {
        "1": upload_to_hf,
        "upload": upload_to_hf,
        "2": download_from_hf,
        "download": download_from_hf,
        "3": test_local_model,
        "test": test_local_model,
    }

    print_banner("BYTE DREAM - QUICK START")

    if argv:
        choice = argv[0].strip().lower()
    else:
        print("What would you like to do?")
        print("1. Upload model to Hugging Face")
        print("2. Download model from Hugging Face")
        print("3. Test local model")
        print("4. Exit")
        print()
        choice = input("Enter choice (1-4): ").strip().lower()

    if choice in ("4", "exit"):
        print("\nGoodbye!")
        return

    action = actions.get(choice)
    if action is None:
        print("❌ Invalid choice!")
        return

    action()
    print("\nDone!")
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\n\nInterrupted!")
        # 130 = 128 + SIGINT, the conventional status for a Ctrl-C abort.
        # Exiting 0 would tell calling shells/scripts the run succeeded.
        sys.exit(130)