"""
Voice Model Module
Load and run the fine-tuned Qwen3 voice model for CEO-style response generation.
Optimized for Hugging Face Spaces GPU instances.
Example usage:
model = VoiceModel.from_hub("username/ceo-voice-model")
response = model.generate("What is your vision for AI?")
"""
import os
from pathlib import Path
from typing import Iterator, Optional
from loguru import logger
try:
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TextIteratorStreamer,
)
from peft import PeftModel
INFERENCE_AVAILABLE = True
except ImportError:
INFERENCE_AVAILABLE = False
logger.warning("Inference dependencies not available")
from .prompt_templates import VOICE_MODEL_SYSTEM_PROMPT, get_voice_prompt
class VoiceModel:
"""
CEO Voice Model for generating authentic responses.
Loads a fine-tuned Qwen3 model with LoRA adapter and generates
responses in the CEO's communication style.
Example:
>>> model = VoiceModel.from_hub("username/ceo-voice-model")
>>> response = model.generate("What's your take on AI regulation?")
>>> print(response)
"""
def __init__(
self,
model,
tokenizer,
system_prompt: Optional[str] = None,
device: str = "auto",
):
"""
Initialize with loaded model and tokenizer.
Args:
model: Loaded HuggingFace model
tokenizer: Loaded tokenizer
system_prompt: Custom system prompt (uses default if None)
device: Device for inference
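Example (illustrative; assumes a merged model that needs no adapter):
>>> tok = AutoTokenizer.from_pretrained("username/ceo-voice-model")
>>> mdl = AutoModelForCausalLM.from_pretrained("username/ceo-voice-model")
>>> voice = VoiceModel(mdl, tok)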
"""
self.model = model
self.tokenizer = tokenizer
self.system_prompt = system_prompt or VOICE_MODEL_SYSTEM_PROMPT
self.device = device
# Ensure padding token
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
@classmethod
def from_hub(
cls,
model_id: str,
adapter_id: Optional[str] = None,
load_in_4bit: bool = True,
load_in_8bit: bool = False,
torch_dtype: str = "bfloat16",
device_map: str = "auto",
system_prompt: Optional[str] = None,
token: Optional[str] = None,
) -> "VoiceModel":
"""
Load voice model from Hugging Face Hub.
Args:
model_id: Base model or merged model ID
adapter_id: Optional adapter ID (if separate from base)
load_in_4bit: Use 4-bit quantization
load_in_8bit: Use 8-bit quantization
torch_dtype: Torch dtype name ("float16", "bfloat16", or "float32")
device_map: Device map passed to from_pretrained (e.g. "auto")
system_prompt: Custom system prompt
token: Hugging Face access token (falls back to the HF_TOKEN env var)
Returns:
VoiceModel instance
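Example (illustrative; the model and adapter IDs below are placeholders):
>>> model = VoiceModel.from_hub("username/ceo-voice-base", adapter_id="username/ceo-voice-lora")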
"""
if not INFERENCE_AVAILABLE:
raise ImportError(
"Inference dependencies not available. Install with:\n"
"pip install torch transformers peft bitsandbytes"
)
token = token or os.environ.get("HF_TOKEN")
# Get torch dtype
dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
dtype = dtype_map.get(torch_dtype, torch.bfloat16)
# Quantization config
quantization_config = None
if load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=dtype,
bnb_4bit_use_double_quant=True,
)
elif load_in_8bit:
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
logger.info(f"Loading model: {model_id}")
# Load base model
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=dtype,
trust_remote_code=True,
token=token,
)
# Load adapter if specified
if adapter_id:
logger.info(f"Loading adapter: {adapter_id}")
model = PeftModel.from_pretrained(model, adapter_id, token=token)
# Load tokenizer (prefer one saved alongside the adapter, if provided)
tokenizer_id = adapter_id or model_id
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_id,
trust_remote_code=True,
token=token,
)
logger.info("Model loaded successfully")
return cls(model, tokenizer, system_prompt, device_map)
@classmethod
def from_local(
cls,
model_path: str | Path,
adapter_path: Optional[str | Path] = None,
load_in_4bit: bool = True,
torch_dtype: str = "bfloat16",
system_prompt: Optional[str] = None,
) -> "VoiceModel":
"""
Load voice model from local path.
Args:
model_path: Path to model
adapter_path: Optional path to adapter
load_in_4bit: Use 4-bit quantization
torch_dtype: Torch dtype
system_prompt: Custom system prompt
Returns:
VoiceModel instance
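Example (illustrative; the path below is a placeholder):
>>> model = VoiceModel.from_local("./outputs/merged_model")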
"""
return cls.from_hub(
model_id=str(model_path),
adapter_id=str(adapter_path) if adapter_path else None,
load_in_4bit=load_in_4bit,
torch_dtype=torch_dtype,
system_prompt=system_prompt,
)
def generate(
self,
user_message: str,
conversation_history: Optional[list[dict]] = None,
max_new_tokens: int = 1024,
temperature: float = 0.7,
top_p: float = 0.9,
top_k: int = 50,
do_sample: bool = True,
repetition_penalty: float = 1.1,
) -> str:
"""
Generate a response to the user message.
Args:
user_message: User's input message
conversation_history: Optional list of prior messages, each a {"role": ..., "content": ...} dict
max_new_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Top-p sampling
top_k: Top-k sampling
do_sample: Whether to sample
repetition_penalty: Repetition penalty
Returns:
Generated response text
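Example (illustrative):
>>> history = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello."}]
>>> reply = model.generate("What is your vision for AI?", conversation_history=history)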
"""
# Build messages
messages = [{"role": "system", "content": self.system_prompt}]
# Add conversation history
if conversation_history:
for msg in conversation_history:
messages.append(msg)
# Add current message
messages.append({"role": "user", "content": user_message})
# Format with chat template
prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
# Tokenize, truncating the prompt so prompt + generation fits a 2048-token budget
inputs = self.tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=2048 - max_new_tokens,
).to(self.model.device)
# Generate
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
do_sample=do_sample,
repetition_penalty=repetition_penalty,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
)
# Decode response only (skip input)
response = self.tokenizer.decode(
outputs[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=True,
)
return response.strip()
def generate_stream(
self,
user_message: str,
conversation_history: Optional[list[dict]] = None,
max_new_tokens: int = 1024,
temperature: float = 0.7,
top_p: float = 0.9,
**kwargs,
) -> Iterator[str]:
"""
Generate a streaming response.
Args:
user_message: User's input message
conversation_history: Optional prior messages as {"role": ..., "content": ...} dicts
max_new_tokens: Maximum tokens
temperature: Sampling temperature
top_p: Top-p sampling
**kwargs: Additional generation kwargs
Yields:
Decoded text chunks as they are generated
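Example (illustrative):
>>> for chunk in model.generate_stream("What is your vision for AI?"):
...     print(chunk, end="", flush=True)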
"""
from threading import Thread
# Build messages
messages = [{"role": "system", "content": self.system_prompt}]
if conversation_history:
messages.extend(conversation_history)
messages.append({"role": "user", "content": user_message})
# Format prompt
prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
# Tokenize with the same 2048-token prompt budget as generate()
inputs = self.tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=2048 - max_new_tokens,
).to(self.model.device)
# Create streamer
streamer = TextIteratorStreamer(
self.tokenizer,
skip_prompt=True,
skip_special_tokens=True,
)
# Generation kwargs
generation_kwargs = dict(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True,
streamer=streamer,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
**kwargs,
)
# Run generation in a background thread so tokens can be yielded as they arrive
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread.start()
# Yield decoded text chunks as the model produces them
for token in streamer:
yield token
thread.join()
def update_system_prompt(self, new_prompt: str) -> None:
"""Update the system prompt."""
self.system_prompt = new_prompt
logger.info("System prompt updated")
def get_system_prompt(self) -> str:
"""Get current system prompt."""
return self.system_prompt
def main():
"""CLI entry point for testing the voice model."""
import argparse
parser = argparse.ArgumentParser(
description="Test the CEO voice model",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python voice_model.py --model username/ceo-voice-model --prompt "What is AI?"
python voice_model.py --model ./local_model --prompt "Your vision?"
""",
)
parser.add_argument("--model", required=True, help="Model ID or path")
parser.add_argument("--adapter", help="Adapter ID or path")
parser.add_argument("--prompt", required=True, help="User prompt")
parser.add_argument("--no-4bit", action="store_true", help="Disable 4-bit")
parser.add_argument("--temperature", type=float, default=0.7)
parser.add_argument("--max-tokens", type=int, default=512)
parser.add_argument("--stream", action="store_true", help="Stream output")
args = parser.parse_args()
# Load model
print(f"Loading model: {args.model}")
model = VoiceModel.from_hub(
model_id=args.model,
adapter_id=args.adapter,
load_in_4bit=not args.no_4bit,
)
# Generate
print(f"\nPrompt: {args.prompt}\n")
print("-" * 50)
if args.stream:
for token in model.generate_stream(
args.prompt,
max_new_tokens=args.max_tokens,
temperature=args.temperature,
):
print(token, end="", flush=True)
print()
else:
response = model.generate(
args.prompt,
max_new_tokens=args.max_tokens,
temperature=args.temperature,
)
print(response)
if __name__ == "__main__":
main()