sino
commited on
Commit
·
0b6f771
1
Parent(s):
a1edc98
Update modeling_maelm.py
Browse files- modeling_maelm.py +13 -13
modeling_maelm.py
CHANGED
|
@@ -192,9 +192,9 @@ class MAEForCausalLM(PreTrainedModel):
|
|
| 192 |
if bk_name == 'MAEViT':
|
| 193 |
ckpt_path = backbone.pop('ckpt') if 'ckpt' in backbone else None
|
| 194 |
self.backbone = MAEViT(**backbone)
|
| 195 |
-
if ckpt_path is not None:
|
| 196 |
-
|
| 197 |
-
|
| 198 |
|
| 199 |
elif bk_name == 'HTSAT':
|
| 200 |
ckpt_path = backbone.pop('ckpt') if 'ckpt' in backbone else None
|
|
@@ -239,16 +239,16 @@ class MAEForCausalLM(PreTrainedModel):
|
|
| 239 |
# float32 --> bfloat16
|
| 240 |
for p in self.parameters():
|
| 241 |
p.data = p.data.to(torch.bfloat16)
|
| 242 |
-
if config.resume_from_checkpoint is not None:
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
elif config.resume_from_pth is not None:
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
|
| 253 |
if False:
|
| 254 |
self.patch_llm()
|
|
|
|
| 192 |
if bk_name == 'MAEViT':
|
| 193 |
ckpt_path = backbone.pop('ckpt') if 'ckpt' in backbone else None
|
| 194 |
self.backbone = MAEViT(**backbone)
|
| 195 |
+
#if ckpt_path is not None:
|
| 196 |
+
# ckpt = torch.load( ckpt_path,'cpu')
|
| 197 |
+
# self.backbone.load_state_dict(ckpt['state_dict'])
|
| 198 |
|
| 199 |
elif bk_name == 'HTSAT':
|
| 200 |
ckpt_path = backbone.pop('ckpt') if 'ckpt' in backbone else None
|
|
|
|
| 239 |
# float32 --> bfloat16
|
| 240 |
for p in self.parameters():
|
| 241 |
p.data = p.data.to(torch.bfloat16)
|
| 242 |
+
#if config.resume_from_checkpoint is not None:
|
| 243 |
+
# drain_loader = True
|
| 244 |
+
# accelerator.load_state(config.resume_from_checkpoint, load_module_strict=False)
|
| 245 |
+
# # start_epoch, start_step, all_step = [int(_.split('_')[1]) for _ in args.resume_from_checkpoint.split('/')[-2].split('-')]
|
| 246 |
+
#elif config.resume_from_pth is not None:
|
| 247 |
+
# print(f'###########loading##########{config.resume_from_pth}###########loading##########')
|
| 248 |
+
# ckpt = torch.load(config.resume_from_pth, map_location='cpu')
|
| 249 |
+
# ckpt_copy = {k[7:]: v for k, v in ckpt.items()}
|
| 250 |
+
# self.load_state_dict(ckpt_copy, strict=False)
|
| 251 |
+
# print(f'###########loaded##########{config.resume_from_pth}###########loaded##########')
|
| 252 |
|
| 253 |
if False:
|
| 254 |
self.patch_llm()
|