Thanh-Lam commited on
Commit
c2b3996
·
1 Parent(s): 419fa8b

Add workaround for deprecated plda parameter in model config, downgrade to pyannote.audio 3.0.1

Browse files
Files changed (2) hide show
  1. requirements.txt +2 -1
  2. src/models.py +38 -2
requirements.txt CHANGED
@@ -1,7 +1,8 @@
1
- pyannote.audio==3.1.1
2
  torch==2.1.2
3
  torchaudio==2.1.2
4
  gradio==4.19.2
5
  huggingface_hub==0.20.3
6
  yt-dlp
7
  numpy<2.0
 
 
1
+ pyannote.audio==3.0.1
2
  torch==2.1.2
3
  torchaudio==2.1.2
4
  gradio==4.19.2
5
  huggingface_hub==0.20.3
6
  yt-dlp
7
  numpy<2.0
8
+ PyYAML
src/models.py CHANGED
@@ -31,11 +31,47 @@ class DiarizationEngine:
31
  segmentation_params: Optional[Dict[str, float]] = None,
32
  clustering_params: Optional[Dict[str, float]] = None,
33
  ) -> None:
 
34
  self.device = self._resolve_device(device)
35
  auth_token = read_hf_token(token, key_path)
36
 
37
- # pyannote.audio 3.1.x uses 'use_auth_token' parameter
38
- pipeline = Pipeline.from_pretrained(model_id, use_auth_token=auth_token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  params = pipeline.parameters()
41
  # Giảm phân mảnh: chỉ cập nhật các khóa thực sự tồn tại để tránh lỗi.
 
31
  segmentation_params: Optional[Dict[str, float]] = None,
32
  clustering_params: Optional[Dict[str, float]] = None,
33
  ) -> None:
34
+ import sys
35
  self.device = self._resolve_device(device)
36
  auth_token = read_hf_token(token, key_path)
37
 
38
+ # Handle model config with deprecated parameters
39
+ try:
40
+ pipeline = Pipeline.from_pretrained(model_id, use_auth_token=auth_token)
41
+ except TypeError as e:
42
+ if "plda" in str(e):
43
+ print(f"WARNING: Model config contains deprecated 'plda' parameter. Loading with workaround...", file=sys.stderr)
44
+ # Download and patch config
45
+ from huggingface_hub import hf_hub_download
46
+ import yaml
47
+ import tempfile
48
+ from pathlib import Path
49
+
50
+ # Download original config
51
+ config_path = hf_hub_download(repo_id=model_id, filename="config.yaml", token=auth_token)
52
+ with open(config_path, 'r') as f:
53
+ config = yaml.safe_load(f)
54
+
55
+ # Remove deprecated params
56
+ if 'params' in config:
57
+ deprecated_keys = ['plda']
58
+ for key in deprecated_keys:
59
+ if key in config['params']:
60
+ print(f" Removing deprecated parameter: {key}", file=sys.stderr)
61
+ del config['params'][key]
62
+
63
+ # Save patched config to temp file
64
+ with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as tmp:
65
+ yaml.dump(config, tmp)
66
+ tmp_path = tmp.name
67
+
68
+ try:
69
+ # Load pipeline from patched config
70
+ pipeline = Pipeline.from_pretrained(tmp_path, use_auth_token=auth_token)
71
+ finally:
72
+ Path(tmp_path).unlink(missing_ok=True)
73
+ else:
74
+ raise
75
 
76
  params = pipeline.parameters()
77
  # Giảm phân mảnh: chỉ cập nhật các khóa thực sự tồn tại để tránh lỗi.