Thanh-Lam committed on
Commit
b0cfc60
·
1 Parent(s): a74861c

Enhance model loading: auto-detect head_hidden_dim from checkpoint and streamline checkpoint loading process

Browse files
Files changed (1) hide show
  1. app.py +25 -15
app.py CHANGED
@@ -84,30 +84,40 @@ class MultiModelProfiler:
84
  # Load model - use MultiTaskSpeakerModel
85
  from src.models import MultiTaskSpeakerModel
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  model = MultiTaskSpeakerModel(
88
  model_name=encoder_name,
89
  num_genders=2,
90
  num_dialects=3,
91
  dropout=0.1,
 
92
  freeze_encoder=True
93
  )
94
 
95
- # Load checkpoint from safetensors
96
- checkpoint_path = model_path / "model.safetensors"
97
- if checkpoint_path.exists():
98
- state_dict = load_safetensors(str(checkpoint_path))
99
  model.load_state_dict(state_dict)
100
- print(f"Loaded checkpoint: {checkpoint_path}")
101
- else:
102
- # Try loading from .pt file
103
- pt_path = model_path / "best_model.pt"
104
- if pt_path.exists():
105
- checkpoint = torch.load(pt_path, map_location=self.device, weights_only=False)
106
- if "model_state_dict" in checkpoint:
107
- model.load_state_dict(checkpoint["model_state_dict"])
108
- else:
109
- model.load_state_dict(checkpoint)
110
- print(f"Loaded checkpoint: {pt_path}")
111
 
112
  model.to(self.device)
113
  model.eval()
 
84
  # Load model - use MultiTaskSpeakerModel
85
  from src.models import MultiTaskSpeakerModel
86
 
87
+ # Load checkpoint first to detect head_hidden_dim
88
+ checkpoint_path = model_path / "model.safetensors"
89
+ pt_path = model_path / "best_model.pt"
90
+ state_dict = None
91
+
92
+ if checkpoint_path.exists():
93
+ state_dict = load_safetensors(str(checkpoint_path))
94
+ elif pt_path.exists():
95
+ checkpoint = torch.load(pt_path, map_location=self.device, weights_only=False)
96
+ if "model_state_dict" in checkpoint:
97
+ state_dict = checkpoint["model_state_dict"]
98
+ else:
99
+ state_dict = checkpoint
100
+
101
+ # Auto-detect head_hidden_dim from checkpoint
102
+ head_hidden_dim = 256 # default
103
+ if state_dict is not None and "gender_head.0.weight" in state_dict:
104
+ # gender_head.0.weight has shape [head_hidden_dim, hidden_size]
105
+ head_hidden_dim = state_dict["gender_head.0.weight"].shape[0]
106
+ print(f"Detected head_hidden_dim: {head_hidden_dim}")
107
+
108
  model = MultiTaskSpeakerModel(
109
  model_name=encoder_name,
110
  num_genders=2,
111
  num_dialects=3,
112
  dropout=0.1,
113
+ head_hidden_dim=head_hidden_dim,
114
  freeze_encoder=True
115
  )
116
 
117
+ # Load checkpoint weights
118
+ if state_dict is not None:
 
 
119
  model.load_state_dict(state_dict)
120
+ print(f"Loaded checkpoint: {checkpoint_path if checkpoint_path.exists() else pt_path}")
 
 
 
 
 
 
 
 
 
 
121
 
122
  model.to(self.device)
123
  model.eval()