ModerRAS committed on
Commit
6a5f135
·
1 Parent(s): c705a32

Support architecture override fine-tuning

Browse files
Files changed (1) hide show
  1. anifilebert/train.py +58 -0
anifilebert/train.py CHANGED
@@ -1264,6 +1264,45 @@ def build_vocab_from_data(data: List[Dict], tokenizer: AnimeTokenizer, vocab_pat
1264
  json.dump(tokenizer.get_vocab(), f, ensure_ascii=False, indent=2)
1265
 
1266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1267
  def main():
1268
  args = parse_args()
1269
  config = Config()
@@ -1434,6 +1473,25 @@ def main():
1434
  model.config.id2label = config.id2label
1435
  model.config.label2id = config.label2id
1436
  model.config.label_schema_version = config.label_schema_version
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1437
  else:
1438
  print("Creating model...")
1439
  selected_model_head = "linear" if args.model_head == "auto" else args.model_head
 
1264
  json.dump(tokenizer.get_vocab(), f, ensure_ascii=False, indent=2)
1265
 
1266
 
1267
def architecture_matches_config(model, config: Config) -> bool:
    """Return whether ``model``'s architecture already satisfies ``config``.

    ``hidden_size``, ``num_hidden_layers``, ``num_attention_heads`` and
    ``intermediate_size`` must be equal on ``model.config`` and ``config``.
    ``max_position_embeddings`` only has to be at least as large on the model
    as the value requested by ``config``.

    Args:
        model: Object exposing a ``config`` attribute that carries the
            dimensions listed above.
        config: Target configuration providing the same attributes.

    Returns:
        ``True`` when every dimension is compatible, ``False`` otherwise. A
        dimension that is missing from ``model.config`` or explicitly set to
        ``None`` counts as a mismatch.
    """
    model_config = model.config

    def model_dim(name: str) -> int:
        # A missing attribute and an explicit ``None`` both mean "unknown".
        # Map them to -1 so they compare as a mismatch instead of letting
        # ``int(None)`` raise TypeError.
        value = getattr(model_config, name, None)
        return -1 if value is None else int(value)

    exact_fields = (
        "hidden_size",
        "num_hidden_layers",
        "num_attention_heads",
        "intermediate_size",
    )
    if any(model_dim(name) != int(getattr(config, name)) for name in exact_fields):
        return False

    # The position table may be larger than requested, never smaller.
    return model_dim("max_position_embeddings") >= int(config.max_position_embeddings)
1276
+
1277
+
1278
def rebuild_model_with_architecture_overrides(source_model, config: Config, model_head: str):
    """Build a model for ``config`` and seed it with matching weights from ``source_model``.

    A new model is created through ``create_model(config, model_head=model_head)``.
    Every entry of its state dict whose name also exists in the source state
    dict with the same shape is overwritten in place with the source values
    (moved to the target entry's device and dtype). All other entries keep
    whatever values ``create_model`` gave them.

    Args:
        source_model: Model whose ``state_dict()`` supplies the weights to reuse.
        config: Configuration describing the architecture to create.
        model_head: Head selector forwarded to ``create_model``.

    Returns:
        A ``(model, stats)`` tuple. ``stats`` holds ``copied_tensors``,
        ``copied_parameters`` (total element count of the copied tensors),
        ``skipped_tensors`` and ``skipped_examples`` (at most the first ten
        skipped names).
    """
    rebuilt = create_model(config, model_head=model_head)
    donor_weights = source_model.state_dict()
    rebuilt_weights = rebuilt.state_dict()

    tensor_count = 0
    element_count = 0
    not_copied = []

    # Gradient tracking is disabled for the in-place copies.
    with torch.no_grad():
        for key, destination in rebuilt_weights.items():
            donor = donor_weights.get(key)
            compatible = (
                donor is not None
                and tuple(donor.shape) == tuple(destination.shape)
            )
            if compatible:
                destination.copy_(
                    donor.to(device=destination.device, dtype=destination.dtype)
                )
                tensor_count += 1
                element_count += destination.numel()
            else:
                # Name absent in the source, or present with a different shape.
                not_copied.append(key)

    rebuilt.load_state_dict(rebuilt_weights)

    summary = {
        "copied_tensors": tensor_count,
        "copied_parameters": element_count,
        "skipped_tensors": len(not_copied),
        "skipped_examples": not_copied[:10],
    }
    return rebuilt, summary
1304
+
1305
+
1306
  def main():
1307
  args = parse_args()
1308
  config = Config()
 
1473
  model.config.id2label = config.id2label
1474
  model.config.label2id = config.label2id
1475
  model.config.label_schema_version = config.label_schema_version
1476
+ if not architecture_matches_config(model, config):
1477
+ print(
1478
+ " Rebuilding model for architecture overrides: "
1479
+ f"layers={config.num_hidden_layers}, heads={config.num_attention_heads}, "
1480
+ f"hidden={config.hidden_size}, intermediate={config.intermediate_size}"
1481
+ )
1482
+ model, architecture_copy = rebuild_model_with_architecture_overrides(
1483
+ source_model=model,
1484
+ config=config,
1485
+ model_head=selected_model_head,
1486
+ )
1487
+ print(
1488
+ " Copied compatible tensors into override architecture: "
1489
+ f"{architecture_copy['copied_tensors']} tensors, "
1490
+ f"{architecture_copy['copied_parameters']:,} params; "
1491
+ f"skipped {architecture_copy['skipped_tensors']} tensors"
1492
+ )
1493
+ if architecture_copy["skipped_examples"]:
1494
+ print(f" Skipped tensor examples: {architecture_copy['skipped_examples']}")
1495
  else:
1496
  print("Creating model...")
1497
  selected_model_head = "linear" if args.model_head == "auto" else args.model_head