binaryMao commited on
Commit
5484bb2
·
verified ·
1 Parent(s): ba79116

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -19
app.py CHANGED
@@ -1,5 +1,4 @@
1
 
2
-
3
  import os, shlex, subprocess, tempfile, traceback, time, glob, gc, shutil
4
  import torch
5
  from huggingface_hub import snapshot_download
@@ -58,30 +57,27 @@ def get_model(name):
58
 
59
  if not nemo_file: raise FileNotFoundError("Fichier .nemo introuvable.")
60
 
61
- # Correctif pour les clés "embedding_model" inattendues et erreur __init__
62
- try:
63
- model = nemo_asr.models.ASRModel.restore_from(
64
- nemo_file,
65
- map_location=torch.device(DEVICE),
66
- strict=False,
67
- override_config_path=None
68
- )
69
- except Exception as e:
70
- print(f"⚠️ Tentative de chargement alternatif : {e}")
71
- # Fallback sans override_config_path
72
- model = nemo_asr.models.ASRModel.restore_from(
73
- nemo_file,
74
- map_location=torch.device(DEVICE),
75
- strict=False
76
- )
77
 
78
- model.to(DEVICE).eval()
 
 
 
 
 
 
 
79
  if DEVICE == "cuda":
80
- model.half()
 
 
 
81
 
82
  _cache[name] = model
83
  return model
84
 
 
85
  # 3. UTILITAIRES
86
  def format_srt_time(sec):
87
  td = time.gmtime(sec)
 
1
 
 
2
  import os, shlex, subprocess, tempfile, traceback, time, glob, gc, shutil
3
  import torch
4
  from huggingface_hub import snapshot_download
 
57
 
58
  if not nemo_file: raise FileNotFoundError("Fichier .nemo introuvable.")
59
 
60
+ # Méthode la plus simple et robuste pour charger le modèle
61
+ print(f"📦 Chargement du modèle depuis : {nemo_file}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
+ # On laisse NeMo gérer tout automatiquement
64
+ model = nemo_asr.models.ASRModel.restore_from(nemo_file)
65
+
66
+ # Déplacement vers le device approprié
67
+ model = model.to(DEVICE)
68
+ model.eval()
69
+
70
+ # Optimisation mémoire pour GPU
71
  if DEVICE == "cuda":
72
+ try:
73
+ model = model.half()
74
+ except Exception as e:
75
+ print(f"⚠️ Impossible de convertir en half precision: {e}")
76
 
77
  _cache[name] = model
78
  return model
79
 
80
+
81
  # 3. UTILITAIRES
82
  def format_srt_time(sec):
83
  td = time.gmtime(sec)