Enzo8930302 committed on
Commit
9e92643
·
verified ·
1 Parent(s): 28d98f8

Upload bytedream/generator.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. bytedream/generator.py +97 -2
bytedream/generator.py CHANGED
@@ -24,6 +24,7 @@ class ByteDreamGenerator:
24
  config_path: str = "config.yaml",
25
  device: str = "cpu",
26
  use_safetensors: bool = True,
 
27
  ):
28
  """
29
  Initialize Byte Dream generator
@@ -33,17 +34,19 @@ class ByteDreamGenerator:
33
  config_path: Path to configuration file
34
  device: Device to run on (default: cpu)
35
  use_safetensors: Use safetensors format if available
 
36
  """
37
  self.device = device
38
  self.config_path = config_path
39
  self.use_safetensors = use_safetensors
 
40
 
41
  # Load configuration
42
  self.config = self._load_config(config_path)
43
 
44
  # Initialize components
45
  print("Initializing Byte Dream Generator...")
46
- self.pipeline = self._initialize_pipeline(model_path)
47
 
48
  # Optimize for CPU
49
  self._optimize_for_cpu()
@@ -92,12 +95,27 @@ class ByteDreamGenerator:
92
  }
93
  }
94
 
95
- def _initialize_pipeline(self, model_path: Optional[str]):
96
  """Initialize the generation pipeline"""
97
  from bytedream.model import create_unet, create_vae, create_text_encoder
98
  from bytedream.scheduler import create_scheduler
99
  from bytedream.pipeline import ByteDreamPipeline
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  # Create model components
102
  print("Creating UNet...")
103
  unet = create_unet(self.config)
@@ -315,3 +333,80 @@ class ByteDreamGenerator:
315
  if torch.cuda.is_available():
316
  torch.cuda.empty_cache()
317
  print("Memory cleared")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  config_path: str = "config.yaml",
25
  device: str = "cpu",
26
  use_safetensors: bool = True,
27
+ hf_repo_id: Optional[str] = None,
28
  ):
29
  """
30
  Initialize Byte Dream generator
 
34
  config_path: Path to configuration file
35
  device: Device to run on (default: cpu)
36
  use_safetensors: Use safetensors format if available
37
+ hf_repo_id: Hugging Face repository ID (e.g., "username/repo")
38
  """
39
  self.device = device
40
  self.config_path = config_path
41
  self.use_safetensors = use_safetensors
42
+ self.hf_repo_id = hf_repo_id
43
 
44
  # Load configuration
45
  self.config = self._load_config(config_path)
46
 
47
  # Initialize components
48
  print("Initializing Byte Dream Generator...")
49
+ self.pipeline = self._initialize_pipeline(model_path, hf_repo_id)
50
 
51
  # Optimize for CPU
52
  self._optimize_for_cpu()
 
95
  }
96
  }
97
 
98
+ def _initialize_pipeline(self, model_path: Optional[str], hf_repo_id: Optional[str] = None):
99
  """Initialize the generation pipeline"""
100
  from bytedream.model import create_unet, create_vae, create_text_encoder
101
  from bytedream.scheduler import create_scheduler
102
  from bytedream.pipeline import ByteDreamPipeline
103
 
104
+ # If HF repo ID is provided, try to load from Hugging Face
105
+ if hf_repo_id is not None:
106
+ print(f"Loading model from Hugging Face: {hf_repo_id}...")
107
+ try:
108
+ from bytedream.pipeline import ByteDreamPipeline
109
+ pipeline = ByteDreamPipeline.from_pretrained(
110
+ hf_repo_id,
111
+ device=self.device,
112
+ dtype=torch.float32,
113
+ )
114
+ return pipeline
115
+ except Exception as e:
116
+ print(f"Error loading from Hugging Face: {e}")
117
+ print("Falling back to local model...")
118
+
119
  # Create model components
120
  print("Creating UNet...")
121
  unet = create_unet(self.config)
 
333
  if torch.cuda.is_available():
334
  torch.cuda.empty_cache()
335
  print("Memory cleared")
336
+
337
def save_pretrained(self, save_directory: str):
    """
    Serialize the underlying pipeline to *save_directory*.

    The resulting directory layout is suitable for uploading to the
    Hugging Face Hub.

    Args:
        save_directory: Destination directory for the model files.

    Raises:
        ValueError: If no pipeline has been initialized yet.
    """
    pipeline = self.pipeline
    if pipeline is None:
        raise ValueError("No pipeline initialized. Cannot save.")
    # Delegate the actual serialization to the pipeline object.
    return pipeline.save_pretrained(save_directory)
348
+
349
def push_to_hub(
    self,
    repo_id: str,
    token: Optional[str] = None,
    private: bool = False,
    commit_message: str = "Upload Byte Dream model",
):
    """
    Publish the current model to the Hugging Face Hub.

    The pipeline is serialized into a temporary staging directory
    (together with the local config file, when present) and that
    directory is then uploaded as a model repository.

    Args:
        repo_id: Repository ID (username/model-name)
        token: Hugging Face API token
        private: Whether to make repository private
        commit_message: Commit message for the upload

    Raises:
        Exception: Re-raises any failure from repository creation or
            from the folder upload, after printing the error.
    """
    import shutil
    import tempfile

    from huggingface_hub import HfApi, create_repo

    print(f"Pushing model to Hugging Face Hub: {repo_id}")

    # Make sure the target repository exists before uploading anything.
    try:
        create_repo(
            repo_id=repo_id,
            token=token,
            private=private,
            exist_ok=True,
            repo_type="model",
        )
    except Exception as err:
        print(f"Error creating repository: {err}")
        raise
    else:
        print("✓ Repository created/verified")

    # Stage everything in a throwaway directory so nothing is left behind.
    with tempfile.TemporaryDirectory() as staging:
        print(f"Saving model to temporary directory: {staging}")
        self.save_pretrained(staging)

        # Ship the YAML config alongside the weights when it exists locally.
        local_config = Path(self.config_path)
        if local_config.exists():
            shutil.copy2(local_config, Path(staging) / "config.yaml")
            print("✓ Config copied")

        # Push the staged folder to the Hub in a single commit.
        try:
            HfApi().upload_folder(
                folder_path=staging,
                repo_id=repo_id,
                token=token,
                repo_type="model",
                commit_message=commit_message,
            )
        except Exception as err:
            print(f"Error uploading to Hub: {err}")
            raise
        else:
            print("✓ Model uploaded successfully!")
            print(f"\n📦 View your model at:")
            print(f"https://huggingface.co/{repo_id}")