svp19 commited on
Commit
d5f414f
·
1 Parent(s): 5a6e5b2

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +8 -9
model.py CHANGED
@@ -44,14 +44,13 @@ class INF5Model(PreTrainedModel):
44
 
45
  # Download and load model weights
46
  # safetensors_path = hf_hub_download(config.name_or_path, filename="model.safetensors")
47
- print(f"Loading model weights from {safetensors_path} (safetensors)...")
48
- state_dict = load_file(safetensors_path, device=str(device))
49
 
50
- # Download vocab.txt from HF Hub
51
- vocab_path = hf_hub_download(config.name_or_path, filename="checkpoints/vocab.txt")
52
 
53
- self.ema_model = torch.compile(
54
- load_model(
55
  DiT,
56
  dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
57
  None, # Skip loading from file, as we load state_dict directly
@@ -59,10 +58,10 @@ class INF5Model(PreTrainedModel):
59
  vocab_file=vocab_path,
60
  device=str(device)
61
  )
62
- )
63
 
64
- # Load state dict into model
65
- self.ema_model.load_state_dict(state_dict, strict=False)
66
 
67
  @classmethod
68
  def from_pretrained(cls, model_name_or_path, *model_args, **kwargs):
 
44
 
45
  # Download and load model weights
46
  # safetensors_path = hf_hub_download(config.name_or_path, filename="model.safetensors")
47
+ # print(f"Loading model weights from {safetensors_path} (safetensors)...")
48
+ # state_dict = load_file(safetensors_path, device=str(device))
49
 
50
+ # # Download vocab.txt from HF Hub
51
+ # vocab_path = hf_hub_download(config.name_or_path, filename="checkpoints/vocab.txt")
52
 
53
+ self.ema_model = load_model(
 
54
  DiT,
55
  dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
56
  None, # Skip loading from file, as we load state_dict directly
 
58
  vocab_file=vocab_path,
59
  device=str(device)
60
  )
61
+
62
 
63
+ # # Load state dict into model
64
+ # self.ema_model.load_state_dict(state_dict, strict=False)
65
 
66
  @classmethod
67
  def from_pretrained(cls, model_name_or_path, *model_args, **kwargs):