FFomy commited on
Commit
399aaa2
·
verified ·
1 Parent(s): 3c53f92

final try

Browse files
Files changed (1) hide show
  1. app.py +45 -45
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  import spaces
3
- # only debug for hf now
4
  REPO_TYPE = "hf"
5
  if REPO_TYPE not in ["hf", "ms"]:
6
  raise ValueError("REPO_TYPE must be either 'hf' for Hugging Face or 'ms' for ModelScope.")
@@ -13,48 +13,40 @@ else:
13
 
14
 
15
  # 1. 定义本地路径和远程仓库ID
16
- MODEL_CACHE_DIR = "./models"
17
- FUN_ASR_NANO_LOCAL_PATH = os.path.join(MODEL_CACHE_DIR, "Fun-ASR-Nano")
18
- SENSE_VOICE_SMALL_LOCAL_PATH = os.path.join(MODEL_CACHE_DIR, "SenseVoiceSmall")
19
- VAD_MODEL_LOCAL_PATH = os.path.join(MODEL_CACHE_DIR, "fsmn-vad")
20
-
21
- # 创建模型缓存目录
22
- os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
23
 
24
- # 设置ModelScope环境变量以使用本地缓存
25
- os.environ['MODELSCOPE_CACHE'] = MODEL_CACHE_DIR
26
- # 禁用远程下载,强制使用本地模型(可选,如果想要确保只使用本地模型)
27
- # os.environ['MODELSCOPE_DISABLE_REMOTE'] = '1'
28
-
29
- print(f"ModelScope缓存目录设置为: {MODEL_CACHE_DIR}")
30
 
31
  if REPO_TYPE == "ms":
32
  FUN_ASR_NANO_REPO_ID = "FunAudioLLM/Fun-ASR-Nano-2512"
33
  SENSE_VOICE_SMALL_REPO_ID = "iic/SenseVoiceSmall"
34
- VAD_MODEL_REPO_ID = "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
35
  else:
36
  FUN_ASR_NANO_REPO_ID = "FunAudioLLM/Fun-ASR-Nano-2512"
37
  SENSE_VOICE_SMALL_REPO_ID = "FunAudioLLM/SenseVoiceSmall"
38
- VAD_MODEL_REPO_ID = "funasr/fsmn-vad"
39
 
40
  # 2. 检查本地是否存在,不存在则下载
41
- def download_model_if_not_exists(repo_id, local_path, model_name):
42
- """如果本地模型不存,则下载模型"""
43
- if not os.path.exists(local_path):
44
- print(f"正在下载模型 {model_name} 到 {local_path} ...")
45
- snapshot_download(
46
- repo_id=repo_id,
47
- local_dir=local_path,
48
- ignore_patterns=["*.onnx"], # 如果你不需要onnx文件,可以过滤掉以节省时间和空间
49
- )
50
- print(f"{model_name} 模型下载完毕!")
51
- else:
52
- print(f"检测到本地 {model_name} 模型文件,跳过下载。")
53
 
54
- # 下载所有需要的模型
55
- download_model_if_not_exists(FUN_ASR_NANO_REPO_ID, FUN_ASR_NANO_LOCAL_PATH, "Fun-ASR-Nano")
56
- download_model_if_not_exists(SENSE_VOICE_SMALL_REPO_ID, SENSE_VOICE_SMALL_LOCAL_PATH, "SenseVoiceSmall")
57
- download_model_if_not_exists(VAD_MODEL_REPO_ID, VAD_MODEL_LOCAL_PATH, "VAD Model")
 
 
 
 
 
 
 
58
 
59
 
60
 
@@ -74,13 +66,13 @@ import importlib
74
  from funasr import AutoModel
75
  from funasr.utils.postprocess_utils import rich_transcription_postprocess
76
 
77
- # Model configurations for local deployment
78
  FUN_ASR_NANO_MODEL_PATH_LIST = [
79
- FUN_ASR_NANO_LOCAL_PATH, # local path
80
  ]
81
 
82
  SENSEVOICE_MODEL_PATH_LIST = [
83
- SENSE_VOICE_SMALL_LOCAL_PATH, # local path
84
  ]
85
 
86
  class LogCapture(io.StringIO):
@@ -101,8 +93,8 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
101
 
102
 
103
  # Check for CUDA availability
104
- # device = "cuda:0" if torch.cuda.is_available() else "cpu"
105
- # logging.info(f"Using device: {device}")
106
 
107
  def download_audio(url, method_choice, proxy_url, proxy_username, proxy_password):
108
  """
@@ -414,7 +406,7 @@ def get_model_options(pipeline_type):
414
  # Dictionary to store loaded models
415
  loaded_models = {}
416
 
417
- @spaces.GPU(duration=40)
418
  def transcribe_audio(audio_input, audio_url, proxy_url, proxy_username, proxy_password, pipeline_type, model_id, download_method, start_time=None, end_time=None, verbose=False):
419
  """
420
  Transcribes audio from a given source using SenseVoice.
@@ -435,9 +427,6 @@ def transcribe_audio(audio_input, audio_url, proxy_url, proxy_username, proxy_pa
435
  Yields:
436
  Tuple[str, str, str or None]: Metrics and messages, transcription text, path to transcription file.
437
  """
438
- current_device = "cuda:0" if torch.cuda.is_available() else "cpu"
439
- device = current_device
440
- logging.info(f"Using device: {device}")
441
  try:
442
  if verbose:
443
  logging.getLogger().setLevel(logging.INFO)
@@ -489,6 +478,7 @@ def transcribe_audio(audio_input, audio_url, proxy_url, proxy_username, proxy_pa
489
 
490
  # Model caching
491
  model_key = (pipeline_type, model_id)
 
492
  if model_key in loaded_models:
493
  model = loaded_models[model_key]
494
  logging.info("Loaded model from cache")
@@ -498,9 +488,9 @@ def transcribe_audio(audio_input, audio_url, proxy_url, proxy_username, proxy_pa
498
  model=model_id,
499
  trust_remote_code=True,
500
  remote_code=f"./Fun-ASR/model.py",
501
- vad_model=VAD_MODEL_LOCAL_PATH, # Use local VAD model path
502
  vad_kwargs={"max_single_segment_time": 30000},
503
- device=device,
504
  disable_update=True,
505
  hub='ms',
506
  )
@@ -508,9 +498,9 @@ def transcribe_audio(audio_input, audio_url, proxy_url, proxy_username, proxy_pa
508
  model = AutoModel(
509
  model=model_id,
510
  trust_remote_code=False,
511
- vad_model=VAD_MODEL_LOCAL_PATH, # Use local VAD model path
512
  vad_kwargs={"max_single_segment_time": 30000},
513
- device=device,
514
  disable_update=True,
515
  hub='ms',
516
  )
@@ -520,6 +510,14 @@ def transcribe_audio(audio_input, audio_url, proxy_url, proxy_username, proxy_pa
520
  yield verbose_messages + error_msg, "", None
521
  return
522
  loaded_models[model_key] = model
 
 
 
 
 
 
 
 
523
 
524
  # Perform the transcription
525
  start_time_perf = time.time()
@@ -547,6 +545,8 @@ def transcribe_audio(audio_input, audio_url, proxy_url, proxy_username, proxy_pa
547
  merge_vad=True,
548
  merge_length_s=15,
549
  )
 
 
550
 
551
  transcription = rich_transcription_postprocess(res[0]["text"])
552
  end_time_perf = time.time()
 
1
  import os
2
  import spaces
3
+
4
  REPO_TYPE = "hf"
5
  if REPO_TYPE not in ["hf", "ms"]:
6
  raise ValueError("REPO_TYPE must be either 'hf' for Hugging Face or 'ms' for ModelScope.")
 
13
 
14
 
15
  # 1. 定义本地路径和远程仓库ID
16
+ FUN_ASR_NANO_LOCAL_PATH = "./Fun-ASR/model"
17
+ SENSE_VOICE_SMALL_LOCAL_PATH = "./Fun-ASR/model/SenseVoiceSmall"
 
 
 
 
 
18
 
 
 
 
 
 
 
19
 
20
  if REPO_TYPE == "ms":
21
  FUN_ASR_NANO_REPO_ID = "FunAudioLLM/Fun-ASR-Nano-2512"
22
  SENSE_VOICE_SMALL_REPO_ID = "iic/SenseVoiceSmall"
 
23
  else:
24
  FUN_ASR_NANO_REPO_ID = "FunAudioLLM/Fun-ASR-Nano-2512"
25
  SENSE_VOICE_SMALL_REPO_ID = "FunAudioLLM/SenseVoiceSmall"
 
26
 
27
  # 2. 检查本地是否存在,不存在则下载
28
+ if not os.path.exists(FUN_ASR_NANO_LOCAL_PATH):
29
+ print(f"在下载模型 Fun-ASR-Nano 到 {FUN_ASR_NANO_LOCAL_PATH} ...")
30
+ snapshot_download(
31
+ repo_id=FUN_ASR_NANO_REPO_ID,
32
+ local_dir=FUN_ASR_NANO_LOCAL_PATH,
33
+ ignore_patterns=["*.onnx"], # 如果你不需要onnx文件,可以过滤掉以节省时间和空间
34
+ )
35
+ print("模型下载完毕!")
36
+ else:
37
+ print("检测到本地模型文件,跳过下载")
 
 
38
 
39
+
40
+ if not os.path.exists(SENSE_VOICE_SMALL_LOCAL_PATH):
41
+ print(f"正在下载模型 {SENSE_VOICE_SMALL_REPO_ID} 到 {SENSE_VOICE_SMALL_LOCAL_PATH} ...")
42
+ snapshot_download(
43
+ repo_id=SENSE_VOICE_SMALL_REPO_ID,
44
+ local_dir=SENSE_VOICE_SMALL_LOCAL_PATH,
45
+ ignore_patterns=["*.onnx"], # 如果你不需要onnx文件,可以过滤掉以节省时间和空间
46
+ )
47
+ print("模型下载完毕!")
48
+ else:
49
+ print("检测到本地模型文件,跳过下载。")
50
 
51
 
52
 
 
66
  from funasr import AutoModel
67
  from funasr.utils.postprocess_utils import rich_transcription_postprocess
68
 
69
+ # Model configurations for Hugging Face deployment
70
  FUN_ASR_NANO_MODEL_PATH_LIST = [
71
+ "Fun-ASR/model", # local path, ms
72
  ]
73
 
74
  SENSEVOICE_MODEL_PATH_LIST = [
75
+ "Fun-ASR/model/SenseVoiceSmall", # local path together with this hf space
76
  ]
77
 
78
  class LogCapture(io.StringIO):
 
93
 
94
 
95
  # Check for CUDA availability
96
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
97
+ logging.info(f"Using device: {device}")
98
 
99
  def download_audio(url, method_choice, proxy_url, proxy_username, proxy_password):
100
  """
 
406
  # Dictionary to store loaded models
407
  loaded_models = {}
408
 
409
+ @spaces.GPU()
410
  def transcribe_audio(audio_input, audio_url, proxy_url, proxy_username, proxy_password, pipeline_type, model_id, download_method, start_time=None, end_time=None, verbose=False):
411
  """
412
  Transcribes audio from a given source using SenseVoice.
 
427
  Yields:
428
  Tuple[str, str, str or None]: Metrics and messages, transcription text, path to transcription file.
429
  """
 
 
 
430
  try:
431
  if verbose:
432
  logging.getLogger().setLevel(logging.INFO)
 
478
 
479
  # Model caching
480
  model_key = (pipeline_type, model_id)
481
+ model = None
482
  if model_key in loaded_models:
483
  model = loaded_models[model_key]
484
  logging.info("Loaded model from cache")
 
488
  model=model_id,
489
  trust_remote_code=True,
490
  remote_code=f"./Fun-ASR/model.py",
491
+ vad_model="fsmn-vad",
492
  vad_kwargs={"max_single_segment_time": 30000},
493
+ device='cpu', # 初始化在cpu,然后推理的时候移到GPU,保证利用好zeroGPU?
494
  disable_update=True,
495
  hub='ms',
496
  )
 
498
  model = AutoModel(
499
  model=model_id,
500
  trust_remote_code=False,
501
+ vad_model="fsmn-vad",
502
  vad_kwargs={"max_single_segment_time": 30000},
503
+ device='cpu',
504
  disable_update=True,
505
  hub='ms',
506
  )
 
510
  yield verbose_messages + error_msg, "", None
511
  return
512
  loaded_models[model_key] = model
513
+
514
+ try:
515
+ model.to(device)
516
+ logging.info(f"Model moved to device: {device}")
517
+ except Exception as e:
518
+ logging.error(f"Error moving model to device: {str(e)}")
519
+ yield verbose_messages + f"Error moving model to device: {str(e)}", "", None
520
+ return
521
 
522
  # Perform the transcription
523
  start_time_perf = time.time()
 
545
  merge_vad=True,
546
  merge_length_s=15,
547
  )
548
+
549
+ model.to('cpu') # Move model back to CPU after inference to free GPU memory
550
 
551
  transcription = rich_transcription_postprocess(res[0]["text"])
552
  end_time_perf = time.time()