| """
|
| Quick Start Script for Hugging Face Integration
|
| Helps you upload/download models easily
|
| """
|
|
|
| import sys
|
| from pathlib import Path
|
|
|
|
|
def print_banner(text, width=60):
    """Print *text* centered between two horizontal rules.

    Args:
        text: Banner text to display.
        width: Total width of the rules and the centering field
            (default 60, matching the rest of the script's output).
    """
    print("\n" + "=" * width)
    print(text.center(width))
    print("=" * width + "\n")
|
|
|
|
|
def check_model_exists():
    """Check whether a trained model is present under ./models/bytedream.

    Prints guidance (train or download instructions) when the model is
    missing or incomplete.

    Returns:
        bool: True when the model directory exists and contains weights —
        either ``unet/pytorch_model.bin`` or a top-level
        ``pytorch_model.bin`` — False otherwise.
    """
    model_path = Path("./models/bytedream")

    if not model_path.exists():
        print("❌ Model directory not found!")
        print("\nPlease train the model first:")
        print(" python train.py")
        print("\nOr download from Hugging Face:")
        print(" python infer.py --hf_repo username/repo --prompt 'test'")
        return False

    # Weights may live in the UNet subdirectory or at the directory root.
    # NOTE(review): the VAE weights (vae/pytorch_model.bin) were computed
    # but never checked in the original; only the UNet/root paths gate
    # the result, so the unused lookup was dropped.
    unet_weights = model_path / "unet" / "pytorch_model.bin"
    root_weights = model_path / "pytorch_model.bin"

    if not (unet_weights.exists() or root_weights.exists()):
        print("⚠ Model directory exists but no weights found!")
        print("Please train the model first.")
        return False

    return True
|
|
|
|
|
def upload_to_hf():
    """Interactively push the locally trained model to the Hugging Face Hub."""
    print_banner("UPLOAD TO HUGGING FACE HUB")

    # Bail out early when there is nothing to upload.
    if not check_model_exists():
        return

    hf_token = input("Enter your Hugging Face token (hf_...): ").strip()
    if not hf_token:
        print("❌ Token is required!")
        return

    target_repo = input("Enter repository ID (e.g., username/ByteDream): ").strip()
    if not target_repo:
        print("❌ Repository ID is required!")
        return

    print(f"\n📤 Uploading to {target_repo}...")

    try:
        # Imported lazily so the menu still works without the package installed.
        from bytedream.generator import ByteDreamGenerator

        print("\nLoading model...")
        generator = ByteDreamGenerator(
            model_path="./models/bytedream",
            config_path="config.yaml",
            device="cpu",
        )

        generator.push_to_hub(
            repo_id=target_repo,
            token=hf_token,
            private=False,
            commit_message="Upload Byte Dream model",
        )

        print("\n✅ SUCCESS!")
        print(f"\n📦 Your model is available at:")
        print(f"https://huggingface.co/{target_repo}")
        print(f"\nTo use this model:")
        print(f" python infer.py --prompt 'your prompt' --hf_repo '{target_repo}'")
        print("=" * 60)

    except Exception as err:
        print(f"\n❌ Error: {err}")
        import traceback

        traceback.print_exc()
|
|
|
|
|
def download_from_hf():
    """Interactively load a model from the Hugging Face Hub and optionally smoke-test it."""
    print_banner("DOWNLOAD FROM HUGGING FACE HUB")

    source_repo = input("Enter repository ID (e.g., username/ByteDream): ").strip()
    if not source_repo:
        print("❌ Repository ID is required!")
        return

    print(f"\n📥 Downloading from {source_repo}...")

    try:
        # Imported lazily so the menu still works without the package installed.
        from bytedream.generator import ByteDreamGenerator

        generator = ByteDreamGenerator(
            hf_repo_id=source_repo,
            config_path="config.yaml",
            device="cpu",
        )

        print("\n✅ Model loaded successfully!")

        # Offer a small, fast sanity generation (few steps, low resolution).
        answer = input("\nGenerate test image? (y/n): ").strip().lower()
        if answer == 'y':
            print("\nGenerating test image...")
            img = generator.generate(
                prompt="test pattern, simple colors",
                width=256,
                height=256,
                num_inference_steps=10,
            )

            out_path = "test_output.png"
            img.save(out_path)
            print(f"✓ Test image saved to: {out_path}")

        print("\nTo generate images:")
        print(f" python infer.py --prompt 'your prompt' --hf_repo '{source_repo}'")
        print(f" HF_REPO_ID={source_repo} python app.py")
        print("=" * 60)

    except Exception as err:
        print(f"\n❌ Error: {err}")
        import traceback

        traceback.print_exc()
|
|
|
|
|
def test_local_model():
    """Load the local model and run a quick smoke-test generation."""
    print_banner("TEST LOCAL MODEL")

    # Nothing to test when no trained model is on disk.
    if not check_model_exists():
        return

    print("Loading local model...")

    try:
        # Imported lazily so the menu still works without the package installed.
        from bytedream.generator import ByteDreamGenerator

        gen = ByteDreamGenerator(
            model_path="./models/bytedream",
            config_path="config.yaml",
            device="cpu",
        )

        print("\n✅ Model loaded successfully!")

        # Small, fast sanity generation: few steps, low resolution.
        print("\nGenerating test image...")
        img = gen.generate(
            prompt="test pattern, simple colors",
            width=256,
            height=256,
            num_inference_steps=10,
        )

        out_path = "test_output.png"
        img.save(out_path)
        print(f"✓ Test image saved to: {out_path}")

        print("\nModel ready for upload!")
        print("To upload: python quick_start.py upload")
        print("=" * 60)

    except Exception as err:
        print(f"\n❌ Error: {err}")
        import traceback

        traceback.print_exc()
|
|
|
|
|
def main():
    """Show the interactive menu and dispatch the chosen action."""
    print_banner("BYTE DREAM - QUICK START")

    print("What would you like to do?")
    print("1. Upload model to Hugging Face")
    print("2. Download model from Hugging Face")
    print("3. Test local model")
    print("4. Exit")
    print()

    selection = input("Enter choice (1-4): ").strip()

    # Menu choices 1-3 dispatch to their handlers; 4 and anything else
    # return without printing the trailing "Done!".
    handlers = {
        "1": upload_to_hf,
        "2": download_from_hf,
        "3": test_local_model,
    }

    if selection == "4":
        print("\nGoodbye!")
        return

    handler = handlers.get(selection)
    if handler is None:
        print("❌ Invalid choice!")
        return

    handler()
    print("\nDone!")
|
|
|
|
|
# Script entry point: run the interactive menu, exiting cleanly (status 0)
# when the user interrupts with Ctrl+C instead of dumping a traceback.
if __name__ == "__main__":
    try:
        main()
    except KeyboardInterrupt:
        print("\n\nInterrupted!")
        sys.exit(0)
|
|
|