rosassebastian2003 commited on
Commit
81ff557
·
1 Parent(s): f2527c6

Se volvio a el approach de usar transformers

Browse files
Files changed (2) hide show
  1. handler.py +46 -118
  2. requirements.txt +1 -1
handler.py CHANGED
@@ -1,125 +1,53 @@
1
- from transformers import pipeline
2
- import torch
3
- import base64
4
  from typing import Dict, List, Any
5
- import io
6
- import scipy.io.wavfile as wavfile
7
- import os
8
- import tempfile
9
- import numpy as np
10
-
11
- # Nombre del modelo (usado como fallback si 'path' no se proporciona)
12
- MODEL_NAME = "Qwen/Qwen3-Omni-30B-A3B-Instruct"
13
 
14
  class EndpointHandler():
15
- def __init__(self, path=""):
16
-
17
- # 1. Configuraciones críticas para la carga del modelo MoE y la funcionalidad de voz
18
- model_kwargs = {
19
- "device_map": "auto", # Optimización para la distribución de pesos en GPU [1]
20
- "torch_dtype": torch.bfloat16 if torch.cuda.is_available() else None,
21
- "enable_audio_output": True # Clave esencial para cargar el componente Talker (generador de voz) [4]
22
- }
23
-
24
- # 2. Carga del pipeline genérico de generación de texto (el wrapper para LLM multimodales) [3]
25
- self.pipeline = pipeline(
26
- task="text-generation",
27
- model=path or MODEL_NAME,
28
- **model_kwargs # Inyección de los parámetros específicos de Qwen3
29
- )
30
-
31
- # 3. System prompt obligatorio para Qwen3-Omni para generar audio natural [4]
32
- self.system_prompt = (
33
- "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, "
34
- "capable of perceiving auditory and visual inputs, as well as generating text and speech."
35
  )
36
-
37
- # 4. Tasa de muestreo del modelo (necesaria para la serialización de audio en __call__)
38
- self.sampling_rate = getattr(self.pipeline.model.config, 'sampling_rate', 24000)
39
-
40
-
41
- def _handle_audio_input(self, data: Dict[str, Any]) -> str:
42
- """ Decodifica la entrada de audio Base64 y la guarda temporalmente como un archivo WAV. """
43
- audio_data_base64 = data.get("audio_data")
44
- if not audio_data_base64:
45
- return None
46
-
47
- temp_file_path = None
48
- try:
49
- audio_bytes = base64.b64decode(audio_data_base64)
50
- # Guardar en un archivo temporal para que el pipeline lo pueda procesar [5]
51
- temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
52
- temp_file.write(audio_bytes)
53
- temp_file.close()
54
- temp_file_path = temp_file.name
55
- return temp_file_path
56
- except Exception as e:
57
- if temp_file_path and os.path.exists(temp_file_path):
58
- os.remove(temp_file_path)
59
- raise ValueError(f"Error al decodificar y guardar el audio Base64: {e}")
60
-
61
- def _handle_audio_output(self, generated_audio: torch.Tensor, sampling_rate: int) -> str:
62
- """ Convierte el tensor de audio de salida a un buffer WAV y lo codifica en Base64. """
63
- audio_array = generated_audio.cpu().numpy().squeeze()
64
- if audio_array.dtype!= np.float32:
65
- audio_array = audio_array.astype(np.float32)
66
-
67
- with io.BytesIO() as buffer:
68
- # Escribir el array como WAV [2]
69
- wavfile.write(buffer, rate=sampling_rate, data=audio_array)
70
- buffer.seek(0)
71
-
72
- # Codificar a Base64 para la respuesta JSON
73
- encoded_audio = base64.b64encode(buffer.read()).decode('utf-8')
74
- return encoded_audio
75
 
76
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
77
- prompt = data.get("inputs")
78
- if not prompt:
79
- raise ValueError("El campo 'inputs' (prompt de texto) es obligatorio.")
80
-
81
- generation_kwargs = data.get("parameters", {})
82
- audio_file_path = None
83
-
84
- try:
85
- # 1. Manejo de I/O de audio (Base64 -> Archivo Temporal)
86
- audio_file_path = self._handle_audio_input(data)
87
-
88
- # 2. El pipeline espera una lista de entradas multimodales (Texto o Audio)
89
- inputs_list = [prompt]
90
- if audio_file_path:
91
- inputs_list.append(audio_file_path)
92
-
93
- # 3. Configuración de generación
94
- generation_kwargs.update({
95
- "system_prompt": self.system_prompt, # Requerido para la calidad de la voz [4]
96
- "return_audio": True, # Solicitamos que la salida contenga el tensor de audio [4]
97
- "max_new_tokens": generation_kwargs.get("max_new_tokens", 512),
98
- })
99
-
100
- # 4. Ejecutar el pipeline
101
- raw_output = self.pipeline(inputs_list, **generation_kwargs)
102
-
103
- # El pipeline devuelve una lista de diccionarios, extraemos el primer resultado
104
- response = raw_output
105
-
106
- final_response = {
107
- "generated_text": response.get("generated_text"),
108
- "audio_output": None
109
- }
110
-
111
- # 5. Post-procesamiento (Tensor -> Base64-WAV)
112
- if "audio_array" in response:
113
- encoded_audio = self._handle_audio_output(response["audio_array"], self.sampling_rate)
114
- final_response["audio_output"] = encoded_audio
115
-
116
- return [final_response]
117
-
118
- except Exception as e:
119
- # Manejo de errores
120
- return [{"error": str(e)}]
121
 
122
- finally:
123
- # 6. Limpieza de archivos temporales (Mantenimiento crítico)
124
- if audio_file_path and os.path.exists(audio_file_path):
125
- os.remove(audio_file_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import Dict, List, Any
2
+ import soundfile as sf
3
+ from transformers import Qwen3OmniMoeForConditionalGeneration, Qwen3OmniMoeProcessor
4
+ from qwen_omni_utils import process_mm_info
 
 
 
 
 
5
 
6
  class EndpointHandler():
7
+ def __init__(self, path="./"):
8
+ self.model = Qwen3OmniMoeForConditionalGeneration.from_pretrained(
9
+ path,
10
+ dtype="auto",
11
+ device_map="auto",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  )
13
+ self.processor = Qwen3OmniMoeProcessor.from_pretrained(path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
16
+ messages = data.get("messages", [])
17
+ use_audio_in_video = data.get("use_audio_in_video", True)
18
+ speaker = data.get("speaker", "Ethan")
19
+
20
+ text = self.processor.apply_chat_template(
21
+ messages,
22
+ tokenize=False,
23
+ add_generation_prompt=True,
24
+ )
25
+ audios, images, videos = process_mm_info(messages, use_audio_in_video=use_audio_in_video)
26
+ inputs = self.processor(
27
+ text=text,
28
+ audio=audios,
29
+ images=images,
30
+ videos=videos,
31
+ return_tensors="pt",
32
+ padding=True,
33
+ use_audio_in_video=use_audio_in_video
34
+ )
35
+ inputs = inputs.to(self.model.device).to(self.model.dtype)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ text_ids, audio = self.model.generate(
38
+ **inputs,
39
+ speaker=speaker,
40
+ thinker_return_dict_in_generate=True,
41
+ use_audio_in_video=use_audio_in_video
42
+ )
43
+ text_output = self.processor.batch_decode(
44
+ text_ids.sequences[:, inputs["input_ids"].shape[1]:],
45
+ skip_special_tokens=True,
46
+ clean_up_tokenization_spaces=False
47
+ )
48
+ result = {"generated_text": text_output}
49
+ if audio is not None:
50
+ # Guarda el audio en un archivo temporal y retorna la ruta
51
+ sf.write("output.wav", audio.reshape(-1).detach().cpu().numpy(), samplerate=24000)
52
+ result["audio_path"] = "output.wav"
53
+ return [result]
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  soundfile
2
- transformers>=4.51.0
3
  torch
4
  qwen-omni-utils
5
  torchvision
 
1
  soundfile
2
+ transformers
3
  torch
4
  qwen-omni-utils
5
  torchvision