File size: 3,080 Bytes
15d6fac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""HuggingFace inference wrapper for MOSS-Audio."""

from __future__ import annotations

import os
from typing import Optional

import torch

from src.audio_io import load_audio
from src.modeling_moss_audio import MossAudioModel
from src.processing_moss_audio import MossAudioProcessor

# Model identifier used when the caller does not choose one: it is the
# fallback returned by read_env_model_id() and the default value of
# MossAudioHFInference's ``model_name_or_path`` argument.
DEFAULT_MODEL_ID = "OpenMOSS-Team/MOSS-Audio-8B-Thinking"


def read_env_model_id() -> str:
    """Return the model id configured through the environment.

    Reads the ``MOSS_AUDIO_MODEL_ID`` environment variable. When the variable
    is not set at all, ``DEFAULT_MODEL_ID`` is returned instead; a variable
    that is set (even to an empty string) is returned unchanged.
    """
    configured = os.getenv("MOSS_AUDIO_MODEL_ID")
    return DEFAULT_MODEL_ID if configured is None else configured


def resolve_device() -> str:
    """Return ``"cuda:0"`` when CUDA is available, otherwise ``"cpu"``."""
    return "cuda:0" if torch.cuda.is_available() else "cpu"


class MossAudioHFInference:
    """Load a MOSS-Audio model with its processor and answer queries.

    The only public entry point is :meth:`generate`, which handles both
    questions about an audio file and plain text-only questions.
    """

    def __init__(
        self,
        model_name_or_path: str = DEFAULT_MODEL_ID,
        device: str = "cuda:0",
        torch_dtype: str = "auto",
        enable_time_marker: bool = True,
    ):
        """Load the model (in eval mode) and its processor.

        Args:
            model_name_or_path: Identifier handed to both
                ``MossAudioModel.from_pretrained`` and
                ``MossAudioProcessor.from_pretrained``.
            device: Device string used as the single target of the model's
                ``device_map``.
            torch_dtype: Forwarded unchanged to the model's
                ``from_pretrained`` call.
            enable_time_marker: Forwarded unchanged to the processor's
                ``from_pretrained`` call.
        """
        # Recorded as given; generate() moves inputs to ``self.model.device``
        # rather than reading this attribute.
        self.device = device

        self.model = MossAudioModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            # A single ""-keyed entry is the Hugging Face convention for
            # placing the whole model on one device.
            device_map={"": device},
        )
        self.model.eval()

        self.processor = MossAudioProcessor.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            enable_time_marker=enable_time_marker,
        )

    @torch.no_grad()
    def generate(
        self,
        question: str,
        audio_path: Optional[str] = None,
        max_new_tokens: int = 1024,
        num_beams: int = 1,
        do_sample: bool = True,
        temperature: float = 1.0,
        top_p: float = 1.0,
        top_k: int = 50,
    ) -> str:
        """Answer ``question``, optionally grounded in an audio file.

        Args:
            question: Text passed to the processor as ``text``.
            audio_path: Optional path of an audio file. When given, it is
                loaded with ``load_audio`` at the processor's
                ``config.mel_sr`` sample rate and passed to the processor as
                a one-element ``audios`` list. When ``None``, the processor
                receives the text only.
            max_new_tokens: Forwarded to ``self.model.generate``.
            num_beams: Forwarded to ``self.model.generate``.
            do_sample: When truthy, sampling is enabled and ``temperature``,
                ``top_p`` and ``top_k`` are forwarded; otherwise
                ``do_sample=False`` is sent and those three are omitted.
            temperature: Sampling temperature (only used when sampling).
            top_p: Top-p value (only used when sampling).
            top_k: Top-k value (only used when sampling).

        Returns:
            The decoded text of the first returned sequence, restricted to
            the tokens after the prompt length, with special tokens skipped.
        """
        # Only supply ``audios`` when there is audio to load.
        if audio_path is None:
            audio_args = {}
        else:
            target_sr = self.processor.config.mel_sr
            audio_args = {"audios": [load_audio(audio_path, sample_rate=target_sr)]}

        batch = self.processor(text=question, **audio_args, return_tensors="pt")
        batch = batch.to(self.model.device)

        # Cast audio features, when present, to the model's dtype.
        audio_features = batch.get("audio_data")
        if audio_features is not None:
            batch["audio_data"] = audio_features.to(self.model.dtype)

        # Boolean mask of the prompt positions holding the audio token id.
        batch["audio_input_mask"] = (
            batch["input_ids"] == self.processor.audio_token_id
        )

        if do_sample:
            sampling_args = {
                "do_sample": True,
                "temperature": temperature,
                "top_p": top_p,
                "top_k": top_k,
            }
        else:
            sampling_args = {"do_sample": False}

        output_ids = self.model.generate(
            **batch,
            max_new_tokens=max_new_tokens,
            num_beams=num_beams,
            use_cache=True,
            **sampling_args,
        )

        # Drop the first ``prompt_length`` positions so only the tokens after
        # the prompt are decoded (assumes the output starts with the prompt).
        prompt_length = batch["input_ids"].shape[1]
        completion_ids = output_ids[0, prompt_length:]
        return self.processor.decode(completion_ids, skip_special_tokens=True)


# Public API of this module: the names exported by a star-import.
__all__ = ["MossAudioHFInference", "read_env_model_id", "resolve_device"]