"""
Model management for STT, TTS, and LLM
Optimized for Hugging Face Zero GPU (H200)
"""
import os
import torch
import spaces
from transformers import (
    AutoModelForSpeechSeq2Seq,
    AutoProcessor,
    pipeline,
    AutoModelForCausalLM,
    AutoTokenizer
)
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer as ParlerTokenizer
import tempfile
from typing import List, Dict
import numpy as np
from scipy.io import wavfile


class ModelManager:
    """Lazily loads and runs the Whisper STT, Parler-TTS, and Llama chat models."""

    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        # Models will be loaded lazily on first use
        self.whisper_pipe = None
        self.tts_model = None
        self.tts_tokenizer = None
        self.llm_model = None
        self.llm_tokenizer = None

    def load_whisper(self):
        """Load Whisper model for STT"""
        if self.whisper_pipe is None:
            print("Loading Whisper model...")
            # Using medium model for better speed/accuracy balance
            model_id = "openai/whisper-medium"
            model = AutoModelForSpeechSeq2Seq.from_pretrained(
                model_id,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True
            )
            model.to(self.device)
            processor = AutoProcessor.from_pretrained(model_id)
            self.whisper_pipe = pipeline(
                "automatic-speech-recognition",
                model=model,
                tokenizer=processor.tokenizer,
                feature_extractor=processor.feature_extractor,
                torch_dtype=self.torch_dtype,
                device=self.device,
                chunk_length_s=30,
                batch_size=16,
            )
            print("Whisper model loaded successfully!")

    def load_tts(self):
        """Load TTS model for text-to-speech"""
        if self.tts_model is None:
            print("Loading TTS model...")
            # Using smaller, faster TTS model
            model_id = "parler-tts/parler-tts-tiny-v1"
            self.tts_model = ParlerTTSForConditionalGeneration.from_pretrained(
                model_id,
                torch_dtype=self.torch_dtype
            ).to(self.device)
            self.tts_tokenizer = ParlerTokenizer.from_pretrained(model_id)
            print("TTS model loaded successfully!")

    def load_llm(self):
        """Load LLM for conversation generation"""
        if self.llm_model is None:
            print("Loading LLM...")
            # Using Llama 3.2 3B - smaller and faster than 7B models
            model_id = "meta-llama/Llama-3.2-3B-Instruct"
            self.llm_tokenizer = AutoTokenizer.from_pretrained(model_id)
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                model_id,
                torch_dtype=self.torch_dtype,
                device_map="auto",
                low_cpu_mem_usage=True
            )
            print("LLM loaded successfully!")

    @spaces.GPU
    def speech_to_text(self, audio_path: str) -> str:
        """Convert speech to text using Whisper - optimized for speed"""
        try:
            self.load_whisper()
            # Validate that the audio file exists
            if not audio_path or not os.path.exists(audio_path):
                print(f"Audio file not found: {audio_path}")
                return ""
            # Check file extension
            if not audio_path.lower().endswith(('.wav', '.mp3', '.flac', '.m4a', '.ogg')):
                print(f"Invalid audio format: {audio_path}")
                return ""
            result = self.whisper_pipe(
                audio_path,
                return_timestamps=False,
                generate_kwargs={
                    "language": "english",
                    "task": "transcribe",
                    "num_beams": 1,  # Faster
                    "temperature": 0.0  # More deterministic
                }
            )
            return result["text"].strip()
        except Exception as e:
            print(f"Error in STT: {e}")
            import traceback
            traceback.print_exc()
            return ""

    @spaces.GPU
    def text_to_speech(self, text: str, accent: str = "American", speaker_name: str = None) -> str:
        """Convert text to speech - optimized for speed with American accent.

        Returns the path to a temporary WAV file, or None if generation fails.
        """
        try:
            self.load_tts()
            # Simplified: just use one clear American voice for speed
            description = "A clear American male voice speaks at moderate pace with good enunciation."
            # Limit text length for faster generation
            if len(text) > 200:
                text = text[:200] + "..."
            # Generate audio with optimized settings
            input_ids = self.tts_tokenizer(description, return_tensors="pt").input_ids.to(self.device)
            prompt_input_ids = self.tts_tokenizer(text, return_tensors="pt").input_ids.to(self.device)
            generation = self.tts_model.generate(
                input_ids=input_ids,
                prompt_input_ids=prompt_input_ids,
                attention_mask=torch.ones_like(input_ids),
                do_sample=False,  # Faster, deterministic
                num_beams=1  # Faster generation
            )
            audio_arr = generation.cpu().numpy().squeeze()
            # Save to a temporary file using scipy
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
            temp_file.close()
            # Clip to [-1, 1] and convert to the int16 range expected by WAV
            audio_int16 = (np.clip(audio_arr, -1.0, 1.0) * 32767).astype(np.int16)
            # Save using scipy.io.wavfile
            wavfile.write(
                temp_file.name,
                self.tts_model.config.sampling_rate,
                audio_int16
            )
            return temp_file.name
        except Exception as e:
            print(f"Error in TTS: {e}")
            # No audio could be generated; callers should handle None
            return None

    @spaces.GPU
    def generate_response(
        self,
        system_prompt: str,
        conversation_history: List[Dict],
        bot_name: str
    ) -> str:
        """Generate conversational response using LLM"""
        try:
            self.load_llm()
            # Format conversation for the model
            messages = [{"role": "system", "content": system_prompt}]
            # Add conversation history (keep the last 6 messages for context)
            for msg in conversation_history[-6:]:
                messages.append({
                    "role": msg["role"],
                    "content": msg["content"]
                })
            # Apply the Llama chat template
            inputs = self.llm_tokenizer.apply_chat_template(
                messages,
                return_tensors="pt",
                add_generation_prompt=True
            ).to(self.device)
            outputs = self.llm_model.generate(
                inputs,
                max_new_tokens=200,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.llm_tokenizer.eos_token_id
            )
            # Decode only the newly generated tokens, not the prompt
            response = self.llm_tokenizer.decode(
                outputs[0][inputs.shape[1]:],
                skip_special_tokens=True
            )
            return response.strip()
        except Exception as e:
            print(f"Error in LLM generation: {e}")
            return "I understand. Could you tell me more about that?"

    @spaces.GPU
    def generate_feedback(self, prompt: str) -> str:
        """Generate detailed feedback using LLM"""
        try:
            self.load_llm()
            # Format feedback prompt for Llama
            messages = [
                {
                    "role": "system",
                    "content": "You are an expert communication coach specializing in sales and professional communication. Provide specific, actionable feedback."
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ]
            inputs = self.llm_tokenizer.apply_chat_template(
                messages,
                return_tensors="pt",
                add_generation_prompt=True
            ).to(self.device)
            outputs = self.llm_model.generate(
                inputs,
                max_new_tokens=500,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                pad_token_id=self.llm_tokenizer.eos_token_id
            )
            feedback = self.llm_tokenizer.decode(
                outputs[0][inputs.shape[1]:],
                skip_special_tokens=True
            )
            return feedback.strip()
        except Exception as e:
            print(f"Error in feedback generation: {e}")
            return "Unable to generate feedback at this time."