Update model.py
Browse files
model.py
CHANGED
|
@@ -44,14 +44,13 @@ class INF5Model(PreTrainedModel):
|
|
| 44 |
|
| 45 |
# Download and load model weights
|
| 46 |
# safetensors_path = hf_hub_download(config.name_or_path, filename="model.safetensors")
|
| 47 |
-
print(f"Loading model weights from {safetensors_path} (safetensors)...")
|
| 48 |
-
state_dict = load_file(safetensors_path, device=str(device))
|
| 49 |
|
| 50 |
-
# Download vocab.txt from HF Hub
|
| 51 |
-
vocab_path = hf_hub_download(config.name_or_path, filename="checkpoints/vocab.txt")
|
| 52 |
|
| 53 |
-
self.ema_model =
|
| 54 |
-
load_model(
|
| 55 |
DiT,
|
| 56 |
dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
|
| 57 |
None, # Skip loading from file, as we load state_dict directly
|
|
@@ -59,10 +58,10 @@ class INF5Model(PreTrainedModel):
|
|
| 59 |
vocab_file=vocab_path,
|
| 60 |
device=str(device)
|
| 61 |
)
|
| 62 |
-
|
| 63 |
|
| 64 |
-
# Load state dict into model
|
| 65 |
-
self.ema_model.load_state_dict(state_dict, strict=False)
|
| 66 |
|
| 67 |
@classmethod
|
| 68 |
def from_pretrained(cls, model_name_or_path, *model_args, **kwargs):
|
|
|
|
| 44 |
|
| 45 |
# Download and load model weights
|
| 46 |
# safetensors_path = hf_hub_download(config.name_or_path, filename="model.safetensors")
|
| 47 |
+
# print(f"Loading model weights from {safetensors_path} (safetensors)...")
|
| 48 |
+
# state_dict = load_file(safetensors_path, device=str(device))
|
| 49 |
|
| 50 |
+
# # Download vocab.txt from HF Hub
|
| 51 |
+
# vocab_path = hf_hub_download(config.name_or_path, filename="checkpoints/vocab.txt")
|
| 52 |
|
| 53 |
+
self.ema_model = load_model(
|
|
|
|
| 54 |
DiT,
|
| 55 |
dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
|
| 56 |
None, # Skip loading from file, as we load state_dict directly
|
|
|
|
| 58 |
vocab_file=vocab_path,
|
| 59 |
device=str(device)
|
| 60 |
)
|
| 61 |
+
|
| 62 |
|
| 63 |
+
# # Load state dict into model
|
| 64 |
+
# self.ema_model.load_state_dict(state_dict, strict=False)
|
| 65 |
|
| 66 |
@classmethod
|
| 67 |
def from_pretrained(cls, model_name_or_path, *model_args, **kwargs):
|