"""
Quick Start Script for Hugging Face Integration.

Helps you upload/download models easily.

Usage:
    python quick_start.py            # interactive menu
    python quick_start.py upload     # upload the local model to the Hub
    python quick_start.py download   # download a model from the Hub
    python quick_start.py test       # smoke-test the local model
"""

import sys
import traceback
from pathlib import Path

# Location of the locally trained model and of the generator configuration.
# Kept as plain strings because they are passed verbatim to ByteDreamGenerator.
MODEL_PATH = "./models/bytedream"
CONFIG_PATH = "config.yaml"

# File the smoke-test image is written to (relative to the working directory).
TEST_OUTPUT = "test_output.png"


def print_banner(text):
    """Print ``text`` centered in a 60-column banner framed by ``=`` rules."""
    rule = "=" * 60
    print("\n" + rule)
    print(text.center(60))
    print(rule + "\n")


def check_model_exists():
    """Check whether a trained model is present under ``MODEL_PATH``.

    The model counts as present when the directory exists and contains
    either ``unet/pytorch_model.bin`` or a top-level ``pytorch_model.bin``.
    VAE weights are not required by this check.

    Returns:
        bool: True when weights were found, False otherwise (a hint on how
        to obtain a model is printed in that case).
    """
    model_path = Path(MODEL_PATH)

    if not model_path.exists():
        print("❌ Model directory not found!")
        print("\nPlease train the model first:")
        print(" python train.py")
        print("\nOr download from Hugging Face:")
        print(" python infer.py --hf_repo username/repo --prompt 'test'")
        return False

    # Accept either the per-component layout or a single flat weights file.
    candidates = (
        model_path / "unet" / "pytorch_model.bin",
        model_path / "pytorch_model.bin",
    )
    if not any(candidate.exists() for candidate in candidates):
        print("⚠ Model directory exists but no weights found!")
        print("Please train the model first.")
        return False

    return True


def _load_generator(**source):
    """Build a CPU ``ByteDreamGenerator`` from ``model_path=`` or ``hf_repo_id=``.

    The import is deferred so that the menu and the model checks still work
    when the ``bytedream`` package (and its heavy dependencies) is missing.
    """
    from bytedream.generator import ByteDreamGenerator

    return ByteDreamGenerator(config_path=CONFIG_PATH, device="cpu", **source)


def _generate_test_image(generator):
    """Generate a small, fast test image and save it to ``TEST_OUTPUT``."""
    print("\nGenerating test image...")
    image = generator.generate(
        prompt="test pattern, simple colors",
        width=256,
        height=256,
        num_inference_steps=10,
    )
    image.save(TEST_OUTPUT)
    print(f"✓ Test image saved to: {TEST_OUTPUT}")


def _report_failure(error):
    """Print ``error`` and the active traceback for a failed action."""
    print(f"\n❌ Error: {error}")
    traceback.print_exc()


def upload_to_hf():
    """Interactively upload the local model to the Hugging Face Hub.

    Prompts for an access token and a repository ID, loads the local model
    on CPU and pushes it as a public repository.

    Returns:
        bool: True when the upload succeeded, False when the model is
        missing, a required answer was empty, or the upload raised.
    """
    print_banner("UPLOAD TO HUGGING FACE HUB")

    if not check_model_exists():
        return False

    token = input("Enter your Hugging Face token (hf_...): ").strip()
    if not token:
        print("❌ Token is required!")
        return False

    repo_id = input("Enter repository ID (e.g., username/ByteDream): ").strip()
    if not repo_id:
        print("❌ Repository ID is required!")
        return False

    print(f"\n📤 Uploading to {repo_id}...")

    try:
        print("\nLoading model...")
        generator = _load_generator(model_path=MODEL_PATH)

        generator.push_to_hub(
            repo_id=repo_id,
            token=token,
            private=False,
            commit_message="Upload Byte Dream model",
        )
    except Exception as e:
        # Top-level boundary of an interactive tool: report and carry on.
        _report_failure(e)
        return False

    print("\n✅ SUCCESS!")
    print("\n📦 Your model is available at:")
    print(f"https://huggingface.co/{repo_id}")
    print("\nTo use this model:")
    print(f" python infer.py --prompt 'your prompt' --hf_repo '{repo_id}'")
    print("=" * 60)
    return True


def download_from_hf():
    """Interactively download a model from the Hugging Face Hub.

    Prompts for a repository ID, loads the model on CPU and optionally
    generates a test image.

    Returns:
        bool: True when the model was loaded (and the optional test image
        generated) without error, False otherwise.
    """
    print_banner("DOWNLOAD FROM HUGGING FACE HUB")

    repo_id = input("Enter repository ID (e.g., username/ByteDream): ").strip()
    if not repo_id:
        print("❌ Repository ID is required!")
        return False

    print(f"\n📥 Downloading from {repo_id}...")

    try:
        generator = _load_generator(hf_repo_id=repo_id)
        print("\n✅ Model loaded successfully!")

        answer = input("\nGenerate test image? (y/n): ").strip().lower()
        if answer == "y":
            _generate_test_image(generator)
    except Exception as e:
        _report_failure(e)
        return False

    print("\nTo generate images:")
    print(f" python infer.py --prompt 'your prompt' --hf_repo '{repo_id}'")
    print(f" HF_REPO_ID={repo_id} python app.py")
    print("=" * 60)
    return True


def test_local_model():
    """Load the local model on CPU and generate one test image.

    Returns:
        bool: True when loading and generation succeeded, False when the
        model is missing or an error was raised.
    """
    print_banner("TEST LOCAL MODEL")

    if not check_model_exists():
        return False

    print("Loading local model...")

    try:
        generator = _load_generator(model_path=MODEL_PATH)
        print("\n✅ Model loaded successfully!")
        _generate_test_image(generator)
    except Exception as e:
        _report_failure(e)
        return False

    print("\nModel ready for upload!")
    print("To upload: python quick_start.py upload")
    print("=" * 60)
    return True


# Sub-commands accepted on the command line, and the menu digits that map
# onto the same actions in interactive mode.
_COMMANDS = {
    "upload": upload_to_hf,
    "download": download_from_hf,
    "test": test_local_model,
}
_MENU = {"1": "upload", "2": "download", "3": "test"}


def main(argv=None):
    """Run the quick-start tool.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). Defaults to ``sys.argv[1:]``. When the first argument is
            ``upload``, ``download`` or ``test`` that action runs directly;
            with no arguments the interactive menu is shown.

    Returns:
        int: 0 on success (or when the user chose to exit), 1 when the
        choice was invalid or the selected action failed.
    """
    args = sys.argv[1:] if argv is None else list(argv)

    print_banner("BYTE DREAM - QUICK START")

    if args:
        command = args[0].strip().lower()
    else:
        print("What would you like to do?")
        print("1. Upload model to Hugging Face")
        print("2. Download model from Hugging Face")
        print("3. Test local model")
        print("4. Exit")
        print()

        choice = input("Enter choice (1-4): ").strip()
        if choice == "4":
            print("\nGoodbye!")
            return 0
        command = _MENU.get(choice)

    action = _COMMANDS.get(command)
    if action is None:
        print("❌ Invalid choice!")
        return 1

    # Only claim completion when the action actually succeeded.
    if not action():
        return 1

    print("\nDone!")
    return 0


if __name__ == "__main__":
    try:
        sys.exit(main())
    except KeyboardInterrupt:
        print("\n\nInterrupted!")
        sys.exit(0)