Spaces:
Sleeping
Sleeping
Update modeling_vit5_kg.py
Browse files- modeling_vit5_kg.py +58 -6
modeling_vit5_kg.py
CHANGED
|
@@ -293,14 +293,66 @@ class KGEnhancedViT5(T5ForConditionalGeneration):
|
|
| 293 |
)
|
| 294 |
|
| 295 |
return generated_ids
|
| 296 |
-
|
| 297 |
@classmethod
|
| 298 |
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
| 299 |
"""
|
| 300 |
Load model from pretrained path
|
| 301 |
"""
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
)
|
| 294 |
|
| 295 |
return generated_ids
|
| 296 |
+
|
| 297 |
@classmethod
|
| 298 |
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
| 299 |
"""
|
| 300 |
Load model from pretrained path
|
| 301 |
"""
|
| 302 |
+
from transformers import AutoConfig
|
| 303 |
+
from huggingface_hub import hf_hub_download
|
| 304 |
+
|
| 305 |
+
# Load config
|
| 306 |
+
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
| 307 |
+
|
| 308 |
+
# Extract custom fields
|
| 309 |
+
use_kg = getattr(config, 'use_kg', True)
|
| 310 |
+
kg_node_features = getattr(config, 'kg_node_features', 300)
|
| 311 |
+
gnn_hidden = getattr(config, 'gnn_hidden', 256)
|
| 312 |
+
gnn_type = getattr(config, 'gnn_type', 'gcn')
|
| 313 |
+
gnn_layers = getattr(config, 'gnn_layers', 2)
|
| 314 |
+
dropout = getattr(config, 'dropout_rate', 0.1)
|
| 315 |
+
|
| 316 |
+
# Create model với đúng architecture
|
| 317 |
+
model = cls(
|
| 318 |
+
config=config,
|
| 319 |
+
kg_node_features=kg_node_features,
|
| 320 |
+
gnn_hidden=gnn_hidden,
|
| 321 |
+
gnn_type=gnn_type,
|
| 322 |
+
gnn_layers=gnn_layers,
|
| 323 |
+
dropout=dropout,
|
| 324 |
+
use_kg=use_kg
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
# Load weights trực tiếp từ checkpoint file
|
| 328 |
+
try:
|
| 329 |
+
# Try safetensors first
|
| 330 |
+
try:
|
| 331 |
+
state_dict_path = hf_hub_download(
|
| 332 |
+
repo_id=pretrained_model_name_or_path,
|
| 333 |
+
filename="model.safetensors"
|
| 334 |
+
)
|
| 335 |
+
from safetensors.torch import load_file
|
| 336 |
+
state_dict = load_file(state_dict_path)
|
| 337 |
+
except:
|
| 338 |
+
# Fallback to pytorch_model.bin
|
| 339 |
+
state_dict_path = hf_hub_download(
|
| 340 |
+
repo_id=pretrained_model_name_or_path,
|
| 341 |
+
filename="pytorch_model.bin"
|
| 342 |
+
)
|
| 343 |
+
state_dict = torch.load(state_dict_path, map_location='cpu')
|
| 344 |
+
|
| 345 |
+
# Load weights với strict=False
|
| 346 |
+
model.load_state_dict(state_dict, strict=False)
|
| 347 |
+
except Exception as e:
|
| 348 |
+
# Fallback: dùng parent's from_pretrained
|
| 349 |
+
print(f"Warning: Could not load weights directly: {e}")
|
| 350 |
+
parent_model = super().from_pretrained(
|
| 351 |
+
pretrained_model_name_or_path,
|
| 352 |
+
*model_args,
|
| 353 |
+
config=config,
|
| 354 |
+
**kwargs
|
| 355 |
+
)
|
| 356 |
+
model.load_state_dict(parent_model.state_dict(), strict=False)
|
| 357 |
+
|
| 358 |
+
return model
|