""" |
|
|
Model management for STT, TTS, and LLM |
|
|
Optimized for Hugging Face Zero GPU (H200) |
|
|
""" |
|
|
|
|
|
import os
import torch
import spaces
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
    AutoModelForCausalLM,
    AutoTokenizer,
)
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer as ParlerTokenizer
import tempfile
from typing import List, Dict, Optional
import numpy as np
from scipy.io import wavfile
import soundfile as sf


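# Models are created lazily on first use; the inference entry points below are
# decorated with @spaces.GPU so that Zero GPU attaches a GPU only for the
# duration of each call.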
class ModelManager:
    """Lazily loads and serves the Whisper STT, Parler-TTS, and LLM models."""

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        # Loaded on demand by the corresponding load_* methods
        self.whisper_pipe = None
        self.tts_model = None
        self.tts_tokenizer = None
        self.llm_model = None
        self.llm_tokenizer = None

    def load_whisper(self):
        """Load Whisper model for STT"""
        if self.whisper_pipe is None:
            print("Loading Whisper model...")

            model_id = "openai/whisper-medium"

            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_id,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
            )
            model.to(self.device)

            processor = AutoProcessor.from_pretrained(model_id)

            self.whisper_pipe = pipeline(
                "automatic-speech-recognition",
                model=model,
                tokenizer=processor.tokenizer,
                feature_extractor=processor.feature_extractor,
                torch_dtype=self.torch_dtype,
                device=self.device,
                chunk_length_s=30,
                batch_size=16,
            )
            print("Whisper model loaded successfully!")

    def load_tts(self):
        """Load TTS model for text-to-speech"""
        if self.tts_model is None:
            print("Loading TTS model...")

            model_id = "parler-tts/parler-tts-tiny-v1"

            self.tts_model = ParlerTTSForConditionalGeneration.from_pretrained(
                model_id,
                torch_dtype=self.torch_dtype,
            ).to(self.device)

            self.tts_tokenizer = ParlerTokenizer.from_pretrained(model_id)
            print("TTS model loaded successfully!")

    def load_llm(self):
        """Load LLM for conversation generation"""
        if self.llm_model is None:
            print("Loading LLM...")

            # Gated checkpoint: the Space needs an accepted license and an HF token to download it.
            model_id = "meta-llama/Llama-3.2-3B-Instruct"

            self.llm_tokenizer = AutoTokenizer.from_pretrained(model_id)
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=self.torch_dtype,
                device_map="auto",
                low_cpu_mem_usage=True,
            )
            print("LLM loaded successfully!")

    @spaces.GPU
    def speech_to_text(self, audio_path: str) -> str:
        """Convert speech to text using Whisper - optimized for speed"""
        try:
            self.load_whisper()

            if not audio_path or not os.path.exists(audio_path):
                print(f"Audio file not found: {audio_path}")
                return ""

            if not audio_path.lower().endswith(('.wav', '.mp3', '.flac', '.m4a', '.ogg')):
                print(f"Invalid audio format: {audio_path}")
                return ""

            result = self.whisper_pipe(
                audio_path,
                return_timestamps=False,
                generate_kwargs={
                    "language": "english",
                    "task": "transcribe",
                    "num_beams": 1,
                    "temperature": 0.0,
                },
            )

            return result["text"].strip()
        except Exception as e:
            print(f"Error in STT: {e}")
            import traceback
            traceback.print_exc()
            return ""

    @spaces.GPU
    def text_to_speech(
        self, text: str, accent: str = "American", speaker_name: Optional[str] = None
    ) -> Optional[str]:
        """Convert text to speech - optimized for speed with American accent"""
        try:
            self.load_tts()

            # The voice description is currently fixed; accent and speaker_name are
            # accepted for API compatibility but not yet used.
            description = "A clear American male voice speaks at moderate pace with good enunciation."

            # Cap the prompt length to keep generation fast
            if len(text) > 200:
                text = text[:200] + "..."

            input_ids = self.tts_tokenizer(description, return_tensors="pt").input_ids.to(self.device)
            prompt_input_ids = self.tts_tokenizer(text, return_tensors="pt").input_ids.to(self.device)

            generation = self.tts_model.generate(
                input_ids=input_ids,
                prompt_input_ids=prompt_input_ids,
                attention_mask=torch.ones_like(input_ids),
                do_sample=False,
                num_beams=1,
            )

            audio_arr = generation.cpu().numpy().squeeze()

            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")

            # Clip to [-1, 1] before converting to int16 to avoid wrap-around artifacts
            audio_int16 = (np.clip(audio_arr, -1.0, 1.0) * 32767).astype(np.int16)

            wavfile.write(
                temp_file.name,
                self.tts_model.config.sampling_rate,
                audio_int16,
            )

            return temp_file.name

        except Exception as e:
            print(f"Error in TTS: {e}")
            return None

    @spaces.GPU
    def generate_response(
        self,
        system_prompt: str,
        conversation_history: List[Dict],
        bot_name: str,
    ) -> str:
        """Generate conversational response using LLM"""
        try:
            self.load_llm()

            messages = [{"role": "system", "content": system_prompt}]

            # Keep only the most recent turns to bound the prompt length
            for msg in conversation_history[-6:]:
                messages.append({
                    "role": msg["role"],
                    "content": msg["content"],
                })

            inputs = self.llm_tokenizer.apply_chat_template(
                messages,
                return_tensors="pt",
                add_generation_prompt=True,
            ).to(self.device)

            outputs = self.llm_model.generate(
                inputs,
                max_new_tokens=200,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.llm_tokenizer.eos_token_id,
            )

            # Decode only the newly generated tokens
            response = self.llm_tokenizer.decode(
                outputs[0][inputs.shape[1]:],
                skip_special_tokens=True,
            )

            return response.strip()

        except Exception as e:
            print(f"Error in LLM generation: {e}")
            return "I understand. Could you tell me more about that?"

    @spaces.GPU
    def generate_feedback(self, prompt: str) -> str:
        """Generate detailed feedback using LLM"""
        try:
            self.load_llm()

            messages = [
                {
                    "role": "system",
                    "content": "You are an expert communication coach specializing in sales and professional communication. Provide specific, actionable feedback.",
                },
                {
                    "role": "user",
                    "content": prompt,
                },
            ]

            inputs = self.llm_tokenizer.apply_chat_template(
                messages,
                return_tensors="pt",
                add_generation_prompt=True,
            ).to(self.device)

            outputs = self.llm_model.generate(
                inputs,
                max_new_tokens=500,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.llm_tokenizer.eos_token_id,
            )

            feedback = self.llm_tokenizer.decode(
                outputs[0][inputs.shape[1]:],
                skip_special_tokens=True,
            )

            return feedback.strip()

        except Exception as e:
            print(f"Error in feedback generation: {e}")
            return "Unable to generate feedback at this time."
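

# A minimal usage sketch (assumptions: run as a script inside the Space with
# network access to download the models, an HF token for the gated Llama
# checkpoint, and a local recording; "sample.wav", the prompt text, and the bot
# name below are hypothetical placeholders).
if __name__ == "__main__":
    manager = ModelManager()

    transcript = manager.speech_to_text("sample.wav")
    print("Transcript:", transcript)

    reply = manager.generate_response(
        system_prompt="You are a friendly sales prospect taking a cold call.",
        conversation_history=[{"role": "user", "content": transcript or "Hello!"}],
        bot_name="Alex",
    )
    print("Reply:", reply)

    print("Audio saved to:", manager.text_to_speech(reply))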