nhatvu205 commited on
Commit
28faaa5
·
verified ·
1 Parent(s): 7eecba3

Update modeling_vit5_kg.py

Browse files
Files changed (1) hide show
  1. modeling_vit5_kg.py +58 -6
modeling_vit5_kg.py CHANGED
@@ -293,14 +293,66 @@ class KGEnhancedViT5(T5ForConditionalGeneration):
293
  )
294
 
295
  return generated_ids
296
-
297
  @classmethod
298
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
299
  """
300
  Load model from pretrained path
301
  """
302
- return super().from_pretrained(
303
- pretrained_model_name_or_path,
304
- *model_args,
305
- **kwargs
306
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  )
294
 
295
  return generated_ids
296
+
297
  @classmethod
298
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
299
  """
300
  Load model from pretrained path
301
  """
302
+ from transformers import AutoConfig
303
+ from huggingface_hub import hf_hub_download
304
+
305
+ # Load config
306
+ config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
307
+
308
+ # Extract custom fields
309
+ use_kg = getattr(config, 'use_kg', True)
310
+ kg_node_features = getattr(config, 'kg_node_features', 300)
311
+ gnn_hidden = getattr(config, 'gnn_hidden', 256)
312
+ gnn_type = getattr(config, 'gnn_type', 'gcn')
313
+ gnn_layers = getattr(config, 'gnn_layers', 2)
314
+ dropout = getattr(config, 'dropout_rate', 0.1)
315
+
316
+ # Create model với đúng architecture
317
+ model = cls(
318
+ config=config,
319
+ kg_node_features=kg_node_features,
320
+ gnn_hidden=gnn_hidden,
321
+ gnn_type=gnn_type,
322
+ gnn_layers=gnn_layers,
323
+ dropout=dropout,
324
+ use_kg=use_kg
325
+ )
326
+
327
+ # Load weights trực tiếp từ checkpoint file
328
+ try:
329
+ # Try safetensors first
330
+ try:
331
+ state_dict_path = hf_hub_download(
332
+ repo_id=pretrained_model_name_or_path,
333
+ filename="model.safetensors"
334
+ )
335
+ from safetensors.torch import load_file
336
+ state_dict = load_file(state_dict_path)
337
+ except:
338
+ # Fallback to pytorch_model.bin
339
+ state_dict_path = hf_hub_download(
340
+ repo_id=pretrained_model_name_or_path,
341
+ filename="pytorch_model.bin"
342
+ )
343
+ state_dict = torch.load(state_dict_path, map_location='cpu')
344
+
345
+ # Load weights với strict=False
346
+ model.load_state_dict(state_dict, strict=False)
347
+ except Exception as e:
348
+ # Fallback: dùng parent's from_pretrained
349
+ print(f"Warning: Could not load weights directly: {e}")
350
+ parent_model = super().from_pretrained(
351
+ pretrained_model_name_or_path,
352
+ *model_args,
353
+ config=config,
354
+ **kwargs
355
+ )
356
+ model.load_state_dict(parent_model.state_dict(), strict=False)
357
+
358
+ return model