Spaces:
Build error
Build error
Add workaround for deprecated plda parameter in model config, downgrade to pyannote.audio 3.0.1
Browse files- requirements.txt +2 -1
- src/models.py +38 -2
requirements.txt
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
-
pyannote.audio==3.
|
| 2 |
torch==2.1.2
|
| 3 |
torchaudio==2.1.2
|
| 4 |
gradio==4.19.2
|
| 5 |
huggingface_hub==0.20.3
|
| 6 |
yt-dlp
|
| 7 |
numpy<2.0
|
|
|
|
|
|
| 1 |
+
pyannote.audio==3.0.1
|
| 2 |
torch==2.1.2
|
| 3 |
torchaudio==2.1.2
|
| 4 |
gradio==4.19.2
|
| 5 |
huggingface_hub==0.20.3
|
| 6 |
yt-dlp
|
| 7 |
numpy<2.0
|
| 8 |
+
PyYAML
|
src/models.py
CHANGED
|
@@ -31,11 +31,47 @@ class DiarizationEngine:
|
|
| 31 |
segmentation_params: Optional[Dict[str, float]] = None,
|
| 32 |
clustering_params: Optional[Dict[str, float]] = None,
|
| 33 |
) -> None:
|
|
|
|
| 34 |
self.device = self._resolve_device(device)
|
| 35 |
auth_token = read_hf_token(token, key_path)
|
| 36 |
|
| 37 |
-
#
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
params = pipeline.parameters()
|
| 41 |
# Giảm phân mảnh: chỉ cập nhật các khóa thực sự tồn tại để tránh lỗi.
|
|
|
|
| 31 |
segmentation_params: Optional[Dict[str, float]] = None,
|
| 32 |
clustering_params: Optional[Dict[str, float]] = None,
|
| 33 |
) -> None:
|
| 34 |
+
import sys
|
| 35 |
self.device = self._resolve_device(device)
|
| 36 |
auth_token = read_hf_token(token, key_path)
|
| 37 |
|
| 38 |
+
# Handle model config with deprecated parameters
|
| 39 |
+
try:
|
| 40 |
+
pipeline = Pipeline.from_pretrained(model_id, use_auth_token=auth_token)
|
| 41 |
+
except TypeError as e:
|
| 42 |
+
if "plda" in str(e):
|
| 43 |
+
print(f"WARNING: Model config contains deprecated 'plda' parameter. Loading with workaround...", file=sys.stderr)
|
| 44 |
+
# Download and patch config
|
| 45 |
+
from huggingface_hub import hf_hub_download
|
| 46 |
+
import yaml
|
| 47 |
+
import tempfile
|
| 48 |
+
from pathlib import Path
|
| 49 |
+
|
| 50 |
+
# Download original config
|
| 51 |
+
config_path = hf_hub_download(repo_id=model_id, filename="config.yaml", token=auth_token)
|
| 52 |
+
with open(config_path, 'r') as f:
|
| 53 |
+
config = yaml.safe_load(f)
|
| 54 |
+
|
| 55 |
+
# Remove deprecated params
|
| 56 |
+
if 'params' in config:
|
| 57 |
+
deprecated_keys = ['plda']
|
| 58 |
+
for key in deprecated_keys:
|
| 59 |
+
if key in config['params']:
|
| 60 |
+
print(f" Removing deprecated parameter: {key}", file=sys.stderr)
|
| 61 |
+
del config['params'][key]
|
| 62 |
+
|
| 63 |
+
# Save patched config to temp file
|
| 64 |
+
with tempfile.NamedTemporaryFile(mode='w', suffix='.yaml', delete=False) as tmp:
|
| 65 |
+
yaml.dump(config, tmp)
|
| 66 |
+
tmp_path = tmp.name
|
| 67 |
+
|
| 68 |
+
try:
|
| 69 |
+
# Load pipeline from patched config
|
| 70 |
+
pipeline = Pipeline.from_pretrained(tmp_path, use_auth_token=auth_token)
|
| 71 |
+
finally:
|
| 72 |
+
Path(tmp_path).unlink(missing_ok=True)
|
| 73 |
+
else:
|
| 74 |
+
raise
|
| 75 |
|
| 76 |
params = pipeline.parameters()
|
| 77 |
# Giảm phân mảnh: chỉ cập nhật các khóa thực sự tồn tại để tránh lỗi.
|