Spaces:
Sleeping
Sleeping
| """HuggingFace inference wrapper for MOSS-Audio.""" | |
| from __future__ import annotations | |
| import os | |
| from typing import Optional | |
| import torch | |
| from src.audio_io import load_audio | |
| from src.modeling_moss_audio import MossAudioModel | |
| from src.processing_moss_audio import MossAudioProcessor | |
# Default checkpoint (Hugging Face Hub repo id) used when no override is given.
DEFAULT_MODEL_ID = "OpenMOSS-Team/MOSS-Audio-8B-Thinking"


def read_env_model_id() -> str:
    """Return the model id/path to load, honoring ``MOSS_AUDIO_MODEL_ID``.

    The environment variable wins when it holds a non-blank value; surrounding
    whitespace is stripped. An unset, empty, or whitespace-only value falls
    back to ``DEFAULT_MODEL_ID``. This stops an accidentally exported empty
    variable from producing ``""`` as the model id.

    Returns:
        The model id or local path string to pass to ``from_pretrained``.
    """
    value = os.environ.get("MOSS_AUDIO_MODEL_ID", "").strip()
    return value or DEFAULT_MODEL_ID
def resolve_device() -> str:
    """Choose a default torch device string for the current machine.

    Returns:
        ``"cuda:0"`` when ``torch.cuda.is_available()`` reports CUDA support,
        otherwise ``"cpu"``.
    """
    has_cuda = torch.cuda.is_available()
    return "cuda:0" if has_cuda else "cpu"
class MossAudioHFInference:
    """Thin wrapper that loads model + processor and exposes a single
    ``generate`` method for both audio-grounded and text-only queries.

    Args:
        model_name_or_path: Hub id or local path passed to both
            ``MossAudioModel.from_pretrained`` and
            ``MossAudioProcessor.from_pretrained``.
        device: Device string the whole model is placed on via
            ``device_map`` (e.g. ``"cuda:0"`` or ``"cpu"``). Pass ``None`` to
            auto-select with ``resolve_device()`` (CUDA if available, else
            CPU).
        torch_dtype: Forwarded to the model's ``from_pretrained`` as
            ``torch_dtype``.
        enable_time_marker: Forwarded to the processor's ``from_pretrained``.
    """

    def __init__(
        self,
        model_name_or_path: str = DEFAULT_MODEL_ID,
        device: Optional[str] = "cuda:0",
        torch_dtype: str = "auto",
        enable_time_marker: bool = True,
    ):
        if device is None:
            # Opt-in auto-selection so CPU-only hosts can use the wrapper
            # without hard-coding a device; the default stays "cuda:0".
            device = resolve_device()
        self.device = device
        self.model = MossAudioModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            # The "" key maps the root module, i.e. the entire model, onto
            # the single requested device.
            device_map={"": device},
        )
        self.model.eval()
        self.processor = MossAudioProcessor.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            enable_time_marker=enable_time_marker,
        )

    def _prepare_inputs(self, question: str, audio_path: Optional[str]):
        """Run the processor on the prompt (plus audio, if given) and move it to the model.

        Returns the processor output with ``audio_data`` (when present) cast
        to the model dtype and an added boolean ``audio_input_mask``.
        """
        if audio_path is not None:
            raw_audio = load_audio(
                audio_path, sample_rate=self.processor.config.mel_sr
            )
            inputs = self.processor(
                text=question, audios=[raw_audio], return_tensors="pt"
            )
        else:
            inputs = self.processor(text=question, return_tensors="pt")
        inputs = inputs.to(self.model.device)
        if inputs.get("audio_data") is not None:
            # Cast audio features to the model's parameter dtype so they
            # match the weights, which may be reduced precision under
            # torch_dtype="auto".
            inputs["audio_data"] = inputs["audio_data"].to(self.model.dtype)
        # True at positions whose token id equals the processor's audio token.
        inputs["audio_input_mask"] = (
            inputs["input_ids"] == self.processor.audio_token_id
        )
        return inputs

    @staticmethod
    def _generation_kwargs(
        max_new_tokens: int,
        num_beams: int,
        do_sample: bool,
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> dict:
        """Build keyword arguments for ``model.generate``.

        ``temperature``, ``top_p`` and ``top_k`` are forwarded only when
        ``do_sample`` is truthy. Otherwise only ``do_sample=False`` is passed.
        """
        kwargs = dict(
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            use_cache=True,
        )
        if do_sample:
            kwargs.update(
                do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k
            )
        else:
            kwargs["do_sample"] = False
        return kwargs

    def generate(
        self,
        question: str,
        audio_path: Optional[str] = None,
        max_new_tokens: int = 1024,
        num_beams: int = 1,
        do_sample: bool = True,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
    ) -> str:
        """Answer ``question``, optionally grounded on the audio file at ``audio_path``.

        Args:
            question: Text prompt passed to the processor.
            audio_path: Optional audio file. It is loaded at the processor's
                ``config.mel_sr`` sample rate. ``None`` gives a text-only query.
            max_new_tokens: Upper bound on generated tokens.
            num_beams: Beam count forwarded to ``model.generate``.
            do_sample: Enable sampling. When falsy, the sampling parameters
                below are not forwarded.
            temperature: Sampling temperature (used only when sampling).
            top_p: Nucleus-sampling threshold (used only when sampling).
            top_k: Top-k cutoff (used only when sampling).

        Returns:
            The decoded text of the tokens generated after the prompt, with
            special tokens skipped.
        """
        inputs = self._prepare_inputs(question, audio_path)
        gen_kwargs = self._generation_kwargs(
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )
        # Pure inference: disable autograd so no graph is retained during
        # generation. This is harmless if the model's own ``generate``
        # already disables gradients.
        with torch.inference_mode():
            generated_ids = self.model.generate(**inputs, **gen_kwargs)
        # Keep only the newly generated tokens (drop the echoed prompt).
        input_len = inputs["input_ids"].shape[1]
        return self.processor.decode(
            generated_ids[0, input_len:], skip_special_tokens=True
        )
# Names exported by ``from <this module> import *``.
__all__ = [
    "MossAudioHFInference",
    "read_env_model_id",
    "resolve_device",
]