| """HuggingFace inference wrapper for MOSS-Audio.""" |
|
|
| from __future__ import annotations |
|
|
| import os |
| from typing import Optional |
|
|
| import torch |
|
|
| from src.audio_io import load_audio |
| from src.modeling_moss_audio import MossAudioModel |
| from src.processing_moss_audio import MossAudioProcessor |
|
|
# Model loaded when neither a constructor argument nor the environment overrides it.
DEFAULT_MODEL_ID = "OpenMOSS-Team/MOSS-Audio-8B-Thinking"


def read_env_model_id() -> str:
    """Return the model id/path to load, honouring ``MOSS_AUDIO_MODEL_ID``.

    The value of the ``MOSS_AUDIO_MODEL_ID`` environment variable is used
    (with surrounding whitespace stripped) when it is set to a non-blank
    string. An unset, empty, or whitespace-only variable falls back to
    ``DEFAULT_MODEL_ID`` instead of yielding an unusable empty model id.

    Returns:
        The Hub repo id or local path to pass to ``from_pretrained``.
    """
    # `or` turns "" (unset / blank after stripping) into the default.
    return os.environ.get("MOSS_AUDIO_MODEL_ID", "").strip() or DEFAULT_MODEL_ID
|
|
|
|
def resolve_device() -> str:
    """Choose the device to run inference on.

    Returns:
        ``"cuda:0"`` (the first GPU) when PyTorch reports CUDA as available,
        otherwise ``"cpu"``.
    """
    return "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
|
|
class MossAudioHFInference:
    """Thin wrapper that loads a MOSS-Audio model + processor and exposes a
    single :meth:`generate` method for both audio-grounded and text-only
    queries.

    Args:
        model_name_or_path: Hub repo id or local directory handed to both
            ``MossAudioModel.from_pretrained`` and
            ``MossAudioProcessor.from_pretrained``.
        device: Device on which the whole model is placed. ``None`` (the
            default) picks one via :func:`resolve_device`, i.e. ``"cuda:0"``
            when CUDA is available and ``"cpu"`` otherwise, so construction
            no longer fails on CPU-only hosts. Any explicit device string is
            used as-is.
        torch_dtype: Forwarded to ``from_pretrained`` (``"auto"`` by default).
        enable_time_marker: Forwarded to the processor's ``from_pretrained``.
    """

    def __init__(
        self,
        model_name_or_path: str = DEFAULT_MODEL_ID,
        device: Optional[str] = None,
        torch_dtype: str = "auto",
        enable_time_marker: bool = True,
    ):
        self.device = device if device is not None else resolve_device()

        self.model = MossAudioModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            # The "" key maps the entire model onto a single device.
            device_map={"": self.device},
        )
        self.model.eval()
        self.processor = MossAudioProcessor.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            enable_time_marker=enable_time_marker,
        )

    def _prepare_inputs(self, question: str, audio_path: Optional[str]):
        """Build processor outputs for one query and move them to the model.

        Loads and resamples the audio (when given) to the processor's
        ``config.mel_sr``, casts ``audio_data`` to the model dtype, and adds
        the ``audio_input_mask`` marking audio-placeholder token positions.
        """
        if audio_path is None:
            inputs = self.processor(text=question, return_tensors="pt")
        else:
            raw_audio = load_audio(
                audio_path, sample_rate=self.processor.config.mel_sr
            )
            inputs = self.processor(
                text=question, audios=[raw_audio], return_tensors="pt"
            )

        inputs = inputs.to(self.model.device)
        # Audio features must match the weight dtype (e.g. when torch_dtype
        # resolves to a half-precision type).
        if inputs.get("audio_data") is not None:
            inputs["audio_data"] = inputs["audio_data"].to(self.model.dtype)

        # Boolean mask of positions holding the processor's audio token.
        inputs["audio_input_mask"] = (
            inputs["input_ids"] == self.processor.audio_token_id
        )
        return inputs

    @staticmethod
    def _generation_kwargs(
        max_new_tokens: int,
        num_beams: int,
        do_sample: bool,
        temperature: float,
        top_p: float,
        top_k: int,
    ) -> dict:
        """Translate the public sampling knobs into ``generate()`` kwargs.

        Sampling parameters are only passed when sampling is actually used.
        ``transformers`` requires a strictly positive temperature when
        sampling, so ``temperature <= 0`` is treated as a request for greedy
        decoding rather than forwarded as an invalid value.
        """
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "num_beams": num_beams,
            "use_cache": True,
        }
        if do_sample and temperature > 0:
            gen_kwargs.update(
                do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k
            )
        else:
            gen_kwargs["do_sample"] = False
        return gen_kwargs

    @torch.no_grad()
    def generate(
        self,
        question: str,
        audio_path: Optional[str] = None,
        max_new_tokens: int = 1024,
        num_beams: int = 1,
        do_sample: bool = True,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
    ) -> str:
        """Answer ``question``, optionally grounded on the audio at ``audio_path``.

        Args:
            question: Text prompt passed to the processor.
            audio_path: Path of an audio file to condition on; ``None`` for a
                text-only query.
            max_new_tokens: Upper bound on generated tokens.
            num_beams: Beam count forwarded to ``generate()``.
            do_sample: Sample when ``True`` (and ``temperature > 0``);
                otherwise decode greedily / with beam search.
            temperature, top_p, top_k: Sampling parameters, used only when
                sampling.

        Returns:
            The decoded continuation for the first sequence, with the prompt
            tokens removed and special tokens skipped.
        """
        inputs = self._prepare_inputs(question, audio_path)
        gen_kwargs = self._generation_kwargs(
            max_new_tokens, num_beams, do_sample, temperature, top_p, top_k
        )
        generated_ids = self.model.generate(**inputs, **gen_kwargs)

        # Assumes generate() returns prompt + continuation (as the original
        # slicing did); drop the prompt so only new text is decoded.
        prompt_len = inputs["input_ids"].shape[1]
        return self.processor.decode(
            generated_ids[0, prompt_len:], skip_special_tokens=True
        )
|
|
|
|
# Public API of this module: the inference wrapper plus the helpers for
# picking the model id (environment override) and the device.
__all__ = ["MossAudioHFInference", "read_env_model_id", "resolve_device"]
|
|