Upload bytedream/generator.py with huggingface_hub
Browse files- bytedream/generator.py +97 -2
bytedream/generator.py
CHANGED
|
@@ -24,6 +24,7 @@ class ByteDreamGenerator:
|
|
| 24 |
config_path: str = "config.yaml",
|
| 25 |
device: str = "cpu",
|
| 26 |
use_safetensors: bool = True,
|
|
|
|
| 27 |
):
|
| 28 |
"""
|
| 29 |
Initialize Byte Dream generator
|
|
@@ -33,17 +34,19 @@ class ByteDreamGenerator:
|
|
| 33 |
config_path: Path to configuration file
|
| 34 |
device: Device to run on (default: cpu)
|
| 35 |
use_safetensors: Use safetensors format if available
|
|
|
|
| 36 |
"""
|
| 37 |
self.device = device
|
| 38 |
self.config_path = config_path
|
| 39 |
self.use_safetensors = use_safetensors
|
|
|
|
| 40 |
|
| 41 |
# Load configuration
|
| 42 |
self.config = self._load_config(config_path)
|
| 43 |
|
| 44 |
# Initialize components
|
| 45 |
print("Initializing Byte Dream Generator...")
|
| 46 |
-
self.pipeline = self._initialize_pipeline(model_path)
|
| 47 |
|
| 48 |
# Optimize for CPU
|
| 49 |
self._optimize_for_cpu()
|
|
@@ -92,12 +95,27 @@ class ByteDreamGenerator:
|
|
| 92 |
}
|
| 93 |
}
|
| 94 |
|
| 95 |
-
def _initialize_pipeline(self, model_path: Optional[str]):
|
| 96 |
"""Initialize the generation pipeline"""
|
| 97 |
from bytedream.model import create_unet, create_vae, create_text_encoder
|
| 98 |
from bytedream.scheduler import create_scheduler
|
| 99 |
from bytedream.pipeline import ByteDreamPipeline
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
# Create model components
|
| 102 |
print("Creating UNet...")
|
| 103 |
unet = create_unet(self.config)
|
|
@@ -315,3 +333,80 @@ class ByteDreamGenerator:
|
|
| 315 |
if torch.cuda.is_available():
|
| 316 |
torch.cuda.empty_cache()
|
| 317 |
print("Memory cleared")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
config_path: str = "config.yaml",
|
| 25 |
device: str = "cpu",
|
| 26 |
use_safetensors: bool = True,
|
| 27 |
+
hf_repo_id: Optional[str] = None,
|
| 28 |
):
|
| 29 |
"""
|
| 30 |
Initialize Byte Dream generator
|
|
|
|
| 34 |
config_path: Path to configuration file
|
| 35 |
device: Device to run on (default: cpu)
|
| 36 |
use_safetensors: Use safetensors format if available
|
| 37 |
+
hf_repo_id: Hugging Face repository ID (e.g., "username/repo")
|
| 38 |
"""
|
| 39 |
self.device = device
|
| 40 |
self.config_path = config_path
|
| 41 |
self.use_safetensors = use_safetensors
|
| 42 |
+
self.hf_repo_id = hf_repo_id
|
| 43 |
|
| 44 |
# Load configuration
|
| 45 |
self.config = self._load_config(config_path)
|
| 46 |
|
| 47 |
# Initialize components
|
| 48 |
print("Initializing Byte Dream Generator...")
|
| 49 |
+
self.pipeline = self._initialize_pipeline(model_path, hf_repo_id)
|
| 50 |
|
| 51 |
# Optimize for CPU
|
| 52 |
self._optimize_for_cpu()
|
|
|
|
| 95 |
}
|
| 96 |
}
|
| 97 |
|
| 98 |
+
def _initialize_pipeline(self, model_path: Optional[str], hf_repo_id: Optional[str] = None):
|
| 99 |
"""Initialize the generation pipeline"""
|
| 100 |
from bytedream.model import create_unet, create_vae, create_text_encoder
|
| 101 |
from bytedream.scheduler import create_scheduler
|
| 102 |
from bytedream.pipeline import ByteDreamPipeline
|
| 103 |
|
| 104 |
+
# If HF repo ID is provided, try to load from Hugging Face
|
| 105 |
+
if hf_repo_id is not None:
|
| 106 |
+
print(f"Loading model from Hugging Face: {hf_repo_id}...")
|
| 107 |
+
try:
|
| 108 |
+
from bytedream.pipeline import ByteDreamPipeline
|
| 109 |
+
pipeline = ByteDreamPipeline.from_pretrained(
|
| 110 |
+
hf_repo_id,
|
| 111 |
+
device=self.device,
|
| 112 |
+
dtype=torch.float32,
|
| 113 |
+
)
|
| 114 |
+
return pipeline
|
| 115 |
+
except Exception as e:
|
| 116 |
+
print(f"Error loading from Hugging Face: {e}")
|
| 117 |
+
print("Falling back to local model...")
|
| 118 |
+
|
| 119 |
# Create model components
|
| 120 |
print("Creating UNet...")
|
| 121 |
unet = create_unet(self.config)
|
|
|
|
| 333 |
if torch.cuda.is_available():
|
| 334 |
torch.cuda.empty_cache()
|
| 335 |
print("Memory cleared")
|
| 336 |
+
|
| 337 |
+
def save_pretrained(self, save_directory: str):
|
| 338 |
+
"""
|
| 339 |
+
Save model to directory for Hugging Face upload
|
| 340 |
+
|
| 341 |
+
Args:
|
| 342 |
+
save_directory: Directory path to save models
|
| 343 |
+
"""
|
| 344 |
+
if self.pipeline is None:
|
| 345 |
+
raise ValueError("No pipeline initialized. Cannot save.")
|
| 346 |
+
|
| 347 |
+
return self.pipeline.save_pretrained(save_directory)
|
| 348 |
+
|
| 349 |
+
def push_to_hub(
|
| 350 |
+
self,
|
| 351 |
+
repo_id: str,
|
| 352 |
+
token: Optional[str] = None,
|
| 353 |
+
private: bool = False,
|
| 354 |
+
commit_message: str = "Upload Byte Dream model",
|
| 355 |
+
):
|
| 356 |
+
"""
|
| 357 |
+
Push model to Hugging Face Hub
|
| 358 |
+
|
| 359 |
+
Args:
|
| 360 |
+
repo_id: Repository ID (username/model-name)
|
| 361 |
+
token: Hugging Face API token
|
| 362 |
+
private: Whether to make repository private
|
| 363 |
+
commit_message: Commit message for the upload
|
| 364 |
+
"""
|
| 365 |
+
from huggingface_hub import create_repo, HfApi
|
| 366 |
+
import tempfile
|
| 367 |
+
import shutil
|
| 368 |
+
|
| 369 |
+
print(f"Pushing model to Hugging Face Hub: {repo_id}")
|
| 370 |
+
|
| 371 |
+
# Create repository
|
| 372 |
+
try:
|
| 373 |
+
create_repo(
|
| 374 |
+
repo_id=repo_id,
|
| 375 |
+
token=token,
|
| 376 |
+
private=private,
|
| 377 |
+
exist_ok=True,
|
| 378 |
+
repo_type="model",
|
| 379 |
+
)
|
| 380 |
+
print("✓ Repository created/verified")
|
| 381 |
+
except Exception as e:
|
| 382 |
+
print(f"Error creating repository: {e}")
|
| 383 |
+
raise
|
| 384 |
+
|
| 385 |
+
# Save to temporary directory
|
| 386 |
+
with tempfile.TemporaryDirectory() as tmp_dir:
|
| 387 |
+
print(f"Saving model to temporary directory: {tmp_dir}")
|
| 388 |
+
self.save_pretrained(tmp_dir)
|
| 389 |
+
|
| 390 |
+
# Copy config file
|
| 391 |
+
config_src = Path(self.config_path)
|
| 392 |
+
if config_src.exists():
|
| 393 |
+
config_dst = Path(tmp_dir) / "config.yaml"
|
| 394 |
+
shutil.copy2(config_src, config_dst)
|
| 395 |
+
print("✓ Config copied")
|
| 396 |
+
|
| 397 |
+
# Upload to Hub
|
| 398 |
+
api = HfApi()
|
| 399 |
+
try:
|
| 400 |
+
api.upload_folder(
|
| 401 |
+
folder_path=tmp_dir,
|
| 402 |
+
repo_id=repo_id,
|
| 403 |
+
token=token,
|
| 404 |
+
repo_type="model",
|
| 405 |
+
commit_message=commit_message,
|
| 406 |
+
)
|
| 407 |
+
print("✓ Model uploaded successfully!")
|
| 408 |
+
print(f"\n📦 View your model at:")
|
| 409 |
+
print(f"https://huggingface.co/{repo_id}")
|
| 410 |
+
except Exception as e:
|
| 411 |
+
print(f"Error uploading to Hub: {e}")
|
| 412 |
+
raise
|