FFomy commited on
Commit
a483939
·
verified ·
1 Parent(s): 399aaa2

seperate .to(device) ?

Browse files
Files changed (1) hide show
  1. app.py +46 -43
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import os
2
  import spaces
3
-
4
  REPO_TYPE = "hf"
5
  if REPO_TYPE not in ["hf", "ms"]:
6
  raise ValueError("REPO_TYPE must be either 'hf' for Hugging Face or 'ms' for ModelScope.")
@@ -13,40 +13,48 @@ else:
13
 
14
 
15
  # 1. 定义本地路径和远程仓库ID
16
- FUN_ASR_NANO_LOCAL_PATH = "./Fun-ASR/model"
17
- SENSE_VOICE_SMALL_LOCAL_PATH = "./Fun-ASR/model/SenseVoiceSmall"
 
 
 
 
 
18
 
 
 
 
 
 
 
19
 
20
  if REPO_TYPE == "ms":
21
  FUN_ASR_NANO_REPO_ID = "FunAudioLLM/Fun-ASR-Nano-2512"
22
  SENSE_VOICE_SMALL_REPO_ID = "iic/SenseVoiceSmall"
 
23
  else:
24
  FUN_ASR_NANO_REPO_ID = "FunAudioLLM/Fun-ASR-Nano-2512"
25
  SENSE_VOICE_SMALL_REPO_ID = "FunAudioLLM/SenseVoiceSmall"
 
26
 
27
  # 2. 检查本地是否存在,不存在则下载
28
- if not os.path.exists(FUN_ASR_NANO_LOCAL_PATH):
29
- print(f"在下载模型 Fun-ASR-Nano 到 {FUN_ASR_NANO_LOCAL_PATH} ...")
30
- snapshot_download(
31
- repo_id=FUN_ASR_NANO_REPO_ID,
32
- local_dir=FUN_ASR_NANO_LOCAL_PATH,
33
- ignore_patterns=["*.onnx"], # 如果你不需要onnx文件,可以过滤掉以节省时间和空间
34
- )
35
- print("模型下载完毕!")
36
- else:
37
- print("检测到本地模型文件,跳过下载")
38
-
 
39
 
40
- if not os.path.exists(SENSE_VOICE_SMALL_LOCAL_PATH):
41
- print(f"正在下载模型 {SENSE_VOICE_SMALL_REPO_ID} 到 {SENSE_VOICE_SMALL_LOCAL_PATH} ...")
42
- snapshot_download(
43
- repo_id=SENSE_VOICE_SMALL_REPO_ID,
44
- local_dir=SENSE_VOICE_SMALL_LOCAL_PATH,
45
- ignore_patterns=["*.onnx"], # 如果你不需要onnx文件,可以过滤掉以节省时间和空间
46
- )
47
- print("模型下载完毕!")
48
- else:
49
- print("检测到本地模型文件,跳过下载。")
50
 
51
 
52
 
@@ -66,13 +74,13 @@ import importlib
66
  from funasr import AutoModel
67
  from funasr.utils.postprocess_utils import rich_transcription_postprocess
68
 
69
- # Model configurations for Hugging Face deployment
70
  FUN_ASR_NANO_MODEL_PATH_LIST = [
71
- "Fun-ASR/model", # local path, ms
72
  ]
73
 
74
  SENSEVOICE_MODEL_PATH_LIST = [
75
- "Fun-ASR/model/SenseVoiceSmall", # local path together with this hf space
76
  ]
77
 
78
  class LogCapture(io.StringIO):
@@ -406,7 +414,7 @@ def get_model_options(pipeline_type):
406
  # Dictionary to store loaded models
407
  loaded_models = {}
408
 
409
- @spaces.GPU()
410
  def transcribe_audio(audio_input, audio_url, proxy_url, proxy_username, proxy_password, pipeline_type, model_id, download_method, start_time=None, end_time=None, verbose=False):
411
  """
412
  Transcribes audio from a given source using SenseVoice.
@@ -478,7 +486,6 @@ def transcribe_audio(audio_input, audio_url, proxy_url, proxy_username, proxy_pa
478
 
479
  # Model caching
480
  model_key = (pipeline_type, model_id)
481
- model = None
482
  if model_key in loaded_models:
483
  model = loaded_models[model_key]
484
  logging.info("Loaded model from cache")
@@ -488,9 +495,9 @@ def transcribe_audio(audio_input, audio_url, proxy_url, proxy_username, proxy_pa
488
  model=model_id,
489
  trust_remote_code=True,
490
  remote_code=f"./Fun-ASR/model.py",
491
- vad_model="fsmn-vad",
492
  vad_kwargs={"max_single_segment_time": 30000},
493
- device='cpu', # 初始化在cpu,然后推理的时候移到GPU,保证利用好zeroGPU?
494
  disable_update=True,
495
  hub='ms',
496
  )
@@ -498,9 +505,9 @@ def transcribe_audio(audio_input, audio_url, proxy_url, proxy_username, proxy_pa
498
  model = AutoModel(
499
  model=model_id,
500
  trust_remote_code=False,
501
- vad_model="fsmn-vad",
502
  vad_kwargs={"max_single_segment_time": 30000},
503
- device='cpu',
504
  disable_update=True,
505
  hub='ms',
506
  )
@@ -511,14 +518,9 @@ def transcribe_audio(audio_input, audio_url, proxy_url, proxy_username, proxy_pa
511
  return
512
  loaded_models[model_key] = model
513
 
514
- try:
515
- model.to(device)
516
- logging.info(f"Model moved to device: {device}")
517
- except Exception as e:
518
- logging.error(f"Error moving model to device: {str(e)}")
519
- yield verbose_messages + f"Error moving model to device: {str(e)}", "", None
520
- return
521
-
522
  # Perform the transcription
523
  start_time_perf = time.time()
524
 
@@ -545,8 +547,9 @@ def transcribe_audio(audio_input, audio_url, proxy_url, proxy_username, proxy_pa
545
  merge_vad=True,
546
  merge_length_s=15,
547
  )
548
-
549
- model.to('cpu') # Move model back to CPU after inference to free GPU memory
 
550
 
551
  transcription = rich_transcription_postprocess(res[0]["text"])
552
  end_time_perf = time.time()
 
1
  import os
2
  import spaces
3
+ # only debug for hf now
4
  REPO_TYPE = "hf"
5
  if REPO_TYPE not in ["hf", "ms"]:
6
  raise ValueError("REPO_TYPE must be either 'hf' for Hugging Face or 'ms' for ModelScope.")
 
13
 
14
 
15
  # 1. 定义本地路径和远程仓库ID
16
+ MODEL_CACHE_DIR = "./models"
17
+ FUN_ASR_NANO_LOCAL_PATH = os.path.join(MODEL_CACHE_DIR, "Fun-ASR-Nano")
18
+ SENSE_VOICE_SMALL_LOCAL_PATH = os.path.join(MODEL_CACHE_DIR, "SenseVoiceSmall")
19
+ VAD_MODEL_LOCAL_PATH = os.path.join(MODEL_CACHE_DIR, "fsmn-vad")
20
+
21
+ # 创建模型缓存目录
22
+ os.makedirs(MODEL_CACHE_DIR, exist_ok=True)
23
 
24
+ # 设置ModelScope环境变量以使用本地缓存
25
+ os.environ['MODELSCOPE_CACHE'] = MODEL_CACHE_DIR
26
+ # 禁用远程下载,强制使用本地模型(可选,如果想要确保只使用本地模型)
27
+ # os.environ['MODELSCOPE_DISABLE_REMOTE'] = '1'
28
+
29
+ print(f"ModelScope缓存目录设置为: {MODEL_CACHE_DIR}")
30
 
31
  if REPO_TYPE == "ms":
32
  FUN_ASR_NANO_REPO_ID = "FunAudioLLM/Fun-ASR-Nano-2512"
33
  SENSE_VOICE_SMALL_REPO_ID = "iic/SenseVoiceSmall"
34
+ VAD_MODEL_REPO_ID = "iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
35
  else:
36
  FUN_ASR_NANO_REPO_ID = "FunAudioLLM/Fun-ASR-Nano-2512"
37
  SENSE_VOICE_SMALL_REPO_ID = "FunAudioLLM/SenseVoiceSmall"
38
+ VAD_MODEL_REPO_ID = "funasr/fsmn-vad"
39
 
40
  # 2. 检查本地是否存在,不存在则下载
41
+ def download_model_if_not_exists(repo_id, local_path, model_name):
42
+ """如果本地模型不存,则下载模型"""
43
+ if not os.path.exists(local_path):
44
+ print(f"正在下载模型 {model_name} 到 {local_path} ...")
45
+ snapshot_download(
46
+ repo_id=repo_id,
47
+ local_dir=local_path,
48
+ ignore_patterns=["*.onnx"], # 如果你不需要onnx文件,可以过滤掉以节省时间和空间
49
+ )
50
+ print(f"{model_name} 模型下载完毕!")
51
+ else:
52
+ print(f"检测到本地 {model_name} 模型文件,跳过下载。")
53
 
54
+ # 下载所有需要的模型
55
+ download_model_if_not_exists(FUN_ASR_NANO_REPO_ID, FUN_ASR_NANO_LOCAL_PATH, "Fun-ASR-Nano")
56
+ download_model_if_not_exists(SENSE_VOICE_SMALL_REPO_ID, SENSE_VOICE_SMALL_LOCAL_PATH, "SenseVoiceSmall")
57
+ download_model_if_not_exists(VAD_MODEL_REPO_ID, VAD_MODEL_LOCAL_PATH, "VAD Model")
 
 
 
 
 
 
58
 
59
 
60
 
 
74
  from funasr import AutoModel
75
  from funasr.utils.postprocess_utils import rich_transcription_postprocess
76
 
77
+ # Model configurations for local deployment
78
  FUN_ASR_NANO_MODEL_PATH_LIST = [
79
+ FUN_ASR_NANO_LOCAL_PATH, # local path
80
  ]
81
 
82
  SENSEVOICE_MODEL_PATH_LIST = [
83
+ SENSE_VOICE_SMALL_LOCAL_PATH, # local path
84
  ]
85
 
86
  class LogCapture(io.StringIO):
 
414
  # Dictionary to store loaded models
415
  loaded_models = {}
416
 
417
+ @spaces.GPU(duration=40)
418
  def transcribe_audio(audio_input, audio_url, proxy_url, proxy_username, proxy_password, pipeline_type, model_id, download_method, start_time=None, end_time=None, verbose=False):
419
  """
420
  Transcribes audio from a given source using SenseVoice.
 
486
 
487
  # Model caching
488
  model_key = (pipeline_type, model_id)
 
489
  if model_key in loaded_models:
490
  model = loaded_models[model_key]
491
  logging.info("Loaded model from cache")
 
495
  model=model_id,
496
  trust_remote_code=True,
497
  remote_code=f"./Fun-ASR/model.py",
498
+ vad_model=VAD_MODEL_LOCAL_PATH, # Use local VAD model path
499
  vad_kwargs={"max_single_segment_time": 30000},
500
+ device=device,
501
  disable_update=True,
502
  hub='ms',
503
  )
 
505
  model = AutoModel(
506
  model=model_id,
507
  trust_remote_code=False,
508
+ vad_model=VAD_MODEL_LOCAL_PATH, # Use local VAD model path
509
  vad_kwargs={"max_single_segment_time": 30000},
510
+ device=device,
511
  disable_update=True,
512
  hub='ms',
513
  )
 
518
  return
519
  loaded_models[model_key] = model
520
 
521
+ # move seperately?
522
+ model.model.to(device)
523
+ model.vad_model.to(device)
 
 
 
 
 
524
  # Perform the transcription
525
  start_time_perf = time.time()
526
 
 
547
  merge_vad=True,
548
  merge_length_s=15,
549
  )
550
+
551
+ model.model.to("cpu")
552
+ model.vad_model.to("cpu")
553
 
554
  transcription = rich_transcription_postprocess(res[0]["text"])
555
  end_time_perf = time.time()