"""
Unified model loading utility supporting ModelScope, HuggingFace, and local paths.
"""

import logging
import os
import threading

import spaces
from funasr_detach import AutoModel
from transformers import AutoModelForCausalLM, AutoTokenizer
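
# Process-wide cache of completed snapshot downloads, shared by every loader
# instance in this process. Keyed by (model_path, source, kwargs) and guarded
# by a lock so concurrent loaders see a consistent view.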
_model_download_cache = {}
_download_cache_lock = threading.Lock()


class ModelSource:
    """Model source enumeration"""

    MODELSCOPE = "modelscope"
    HUGGINGFACE = "huggingface"
    LOCAL = "local"
    AUTO = "auto"

class UnifiedModelLoader:
    """Unified model loader"""

    def __init__(self):
        self.logger = logging.getLogger(__name__)

    def _cached_snapshot_download(self, model_path: str, source: str, **kwargs) -> str:
        """
        Cached version of snapshot_download to avoid repeated downloads.

        Args:
            model_path: Model path or ID to download
            source: Model source ('modelscope' or 'huggingface')
            **kwargs: Additional arguments for snapshot_download

        Returns:
            Local path to the downloaded model
        """
        # Include kwargs in the key so that, e.g., different revisions of the
        # same model are cached separately.
        cache_key = (model_path, source, str(sorted(kwargs.items())))

        with _download_cache_lock:
            if cache_key in _model_download_cache:
                cached_path = _model_download_cache[cache_key]
                self.logger.info(f"Using cached download for {model_path} from {source}: {cached_path}")
                return cached_path

        # The download runs outside the lock, so two threads that miss the cache
        # at the same time may both download; the hub clients reuse their on-disk
        # caches, and the second cache write below is a harmless overwrite.
        if source == ModelSource.MODELSCOPE:
            from modelscope.hub.snapshot_download import snapshot_download
            local_path = snapshot_download(model_path, **kwargs)
        elif source == ModelSource.HUGGINGFACE:
            from huggingface_hub import snapshot_download
            local_path = snapshot_download(model_path, **kwargs)
        else:
            raise ValueError(f"Unsupported source for cached download: {source}")

        with _download_cache_lock:
            _model_download_cache[cache_key] = local_path

        self.logger.info(f"Downloaded and cached {model_path} from {source}: {local_path}")
        return local_path

    def detect_model_source(self, model_path: str) -> str:
        """
        Automatically detect the model source.

        Args:
            model_path: Model path or ID

        Returns:
            Model source type
        """
        # An existing path, or any absolute path, is treated as local.
        if os.path.exists(model_path) or os.path.isabs(model_path):
            return ModelSource.LOCAL

        # Hub-style IDs look like "org/model-name".
        if "/" in model_path and not model_path.startswith("http"):
            if "modelscope" in model_path.lower() or self._is_modelscope_format(model_path):
                return ModelSource.MODELSCOPE
            else:
                # Ambiguous hub IDs default to HuggingFace.
                return ModelSource.HUGGINGFACE

        return ModelSource.LOCAL

    def _is_modelscope_format(self, model_path: str) -> bool:
        """Detect whether this looks like a ModelScope model ID."""
        # Placeholder: no ModelScope-specific ID patterns are registered yet, so
        # this always returns False and ambiguous IDs fall through to HuggingFace.
        modelscope_patterns = []
        return any(pattern in model_path for pattern in modelscope_patterns)

    @spaces.GPU
    def load_transformers_model(
        self,
        model_path: str,
        source: str = ModelSource.AUTO,
        **kwargs
    ) -> tuple:
        """
        Load a Transformers model (for StepAudioTTS).

        Args:
            model_path: Model path or ID
            source: Model source; "auto" means auto-detect
            **kwargs: Other parameters (e.g. torch_dtype, device_map)

        Returns:
            (model, tokenizer, model_path) tuple, where model_path is the
            resolved local path the weights were loaded from
        """
        if source == ModelSource.AUTO:
            source = self.detect_model_source(model_path)

        self.logger.info(f"Loading Transformers model from {source}: {model_path}")

        try:
            if source == ModelSource.LOCAL:
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    torch_dtype=kwargs.get("torch_dtype"),
                    device_map=kwargs.get("device_map", "auto"),
                    trust_remote_code=True,
                    local_files_only=True
                )
                tokenizer = AutoTokenizer.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    local_files_only=True
                )

            elif source == ModelSource.MODELSCOPE:
                from modelscope import AutoModelForCausalLM as MSAutoModelForCausalLM
                from modelscope import AutoTokenizer as MSAutoTokenizer

                # Fetch (or reuse) a local snapshot, then load strictly from disk.
                model_path = self._cached_snapshot_download(model_path, ModelSource.MODELSCOPE)
                model = MSAutoModelForCausalLM.from_pretrained(
                    model_path,
                    torch_dtype=kwargs.get("torch_dtype"),
                    device_map=kwargs.get("device_map", "auto"),
                    trust_remote_code=True,
                    local_files_only=True
                )
                tokenizer = MSAutoTokenizer.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    local_files_only=True
                )

            elif source == ModelSource.HUGGINGFACE:
                # Fetch (or reuse) a local snapshot, then load strictly from disk.
                model_path = self._cached_snapshot_download(model_path, ModelSource.HUGGINGFACE)
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    torch_dtype=kwargs.get("torch_dtype"),
                    device_map=kwargs.get("device_map", "auto"),
                    trust_remote_code=True,
                    local_files_only=True
                )
                tokenizer = AutoTokenizer.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    local_files_only=True
                )

            else:
                raise ValueError(f"Unsupported model source: {source}")

            self.logger.info(f"Successfully loaded model from {source}")
            return model, tokenizer, model_path

        except Exception as e:
            self.logger.error(f"Failed to load model from {source}: {e}")
            raise

    @spaces.GPU
    def load_funasr_model(
        self,
        repo_path: str,
        model_path: str,
        source: str = ModelSource.AUTO,
        **kwargs
    ) -> AutoModel:
        """
        Load a FunASR model (for StepAudioTokenizer).

        Args:
            repo_path: Local FunASR repo path, forwarded to AutoModel
            model_path: Model path or ID
            source: Model source; "auto" means auto-detect
            **kwargs: Other parameters

        Returns:
            FunASR AutoModel instance
        """
        if source == ModelSource.AUTO:
            source = self.detect_model_source(model_path)

        self.logger.info(f"Loading FunASR model from {source}: {model_path}")

        try:
            model_revision = kwargs.pop("model_revision", "main")

            # Map the detected source onto FunASR's model_hub identifiers.
            if source == ModelSource.LOCAL:
                model_hub = "local"
            elif source == ModelSource.MODELSCOPE:
                model_hub = "ms"
            elif source == ModelSource.HUGGINGFACE:
                model_hub = "hf"
            else:
                raise ValueError(f"Unsupported model source: {source}")

            model = AutoModel(
                repo_path=repo_path,
                model=model_path,
                model_hub=model_hub,
                model_revision=model_revision,
                **kwargs
            )

            self.logger.info(f"Successfully loaded FunASR model from {source}")
            return model

        except Exception as e:
            self.logger.error(f"Failed to load FunASR model from {source}: {e}")
            raise

    def resolve_model_path(
        self,
        base_path: str,
        model_name: str,
        source: str = ModelSource.AUTO
    ) -> str:
        """
        Resolve a model path against a local base directory.

        Args:
            base_path: Base directory that may hold a local copy of the model
            model_name: Model name or hub ID
            source: Model source

        Returns:
            Resolved model path
        """
        if source == ModelSource.AUTO:
            # Prefer an existing local copy under base_path.
            local_path = os.path.join(base_path, model_name)
            if os.path.exists(local_path):
                return local_path

            # Otherwise fall back to the bare name for hub resolution.
            return model_name

        elif source == ModelSource.LOCAL:
            return os.path.join(base_path, model_name)

        else:
            # Remote sources use the model name/ID directly.
            return model_name


# Shared module-level loader instance.
model_loader = UnifiedModelLoader()
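
# Minimal usage sketch (illustrative paths and IDs, not real checkpoints). The
# source-detection and path-resolution helpers run without downloading anything:
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Hub-style "org/name" IDs with no local match resolve to a remote source.
    print(model_loader.detect_model_source("some-org/some-model"))   # huggingface
    print(model_loader.detect_model_source("/opt/models/step-tts"))  # local

    # resolve_model_path prefers an existing copy under base_path and otherwise
    # returns the bare name for hub resolution.
    print(model_loader.resolve_model_path("/opt/models", "some-model"))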