""" Voice Model Module Load and run the fine-tuned Qwen3 voice model for CEO-style response generation. Optimized for Hugging Face Spaces GPU instances. Example usage: model = VoiceModel.from_hub("username/ceo-voice-model") response = model.generate("What is your vision for AI?") """ import os from pathlib import Path from typing import Iterator, Optional from loguru import logger try: import torch from transformers import ( AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer, ) from peft import PeftModel INFERENCE_AVAILABLE = True except ImportError: INFERENCE_AVAILABLE = False logger.warning("Inference dependencies not available") from .prompt_templates import VOICE_MODEL_SYSTEM_PROMPT, get_voice_prompt class VoiceModel: """ CEO Voice Model for generating authentic responses. Loads a fine-tuned Qwen3 model with LoRA adapter and generates responses in the CEO's communication style. Example: >>> model = VoiceModel.from_hub("username/ceo-voice-model") >>> response = model.generate("What's your take on AI regulation?") >>> print(response) """ def __init__( self, model, tokenizer, system_prompt: Optional[str] = None, device: str = "auto", ): """ Initialize with loaded model and tokenizer. Args: model: Loaded HuggingFace model tokenizer: Loaded tokenizer system_prompt: Custom system prompt (uses default if None) device: Device for inference """ self.model = model self.tokenizer = tokenizer self.system_prompt = system_prompt or VOICE_MODEL_SYSTEM_PROMPT self.device = device # Ensure padding token if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token @classmethod def from_hub( cls, model_id: str, adapter_id: Optional[str] = None, load_in_4bit: bool = True, load_in_8bit: bool = False, torch_dtype: str = "bfloat16", device_map: str = "auto", system_prompt: Optional[str] = None, token: Optional[str] = None, ) -> "VoiceModel": """ Load voice model from Hugging Face Hub. Args: model_id: Base model or merged model ID adapter_id: Optional adapter ID (if separate from base) load_in_4bit: Use 4-bit quantization load_in_8bit: Use 8-bit quantization torch_dtype: Torch dtype device_map: Device mapping system_prompt: Custom system prompt token: HF token Returns: VoiceModel instance """ if not INFERENCE_AVAILABLE: raise ImportError( "Inference dependencies not available. 
                "pip install torch transformers peft bitsandbytes"
            )

        token = token or os.environ.get("HF_TOKEN")

        # Get torch dtype
        dtype_map = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32,
        }
        dtype = dtype_map.get(torch_dtype, torch.bfloat16)

        # Quantization config
        quantization_config = None
        if load_in_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=dtype,
                bnb_4bit_use_double_quant=True,
            )
        elif load_in_8bit:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        logger.info(f"Loading model: {model_id}")

        # Load base model
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=quantization_config,
            device_map=device_map,
            torch_dtype=dtype,
            trust_remote_code=True,
            token=token,
        )

        # Load adapter if specified
        if adapter_id:
            logger.info(f"Loading adapter: {adapter_id}")
            model = PeftModel.from_pretrained(model, adapter_id, token=token)

        # Load tokenizer
        tokenizer_id = adapter_id or model_id
        tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_id,
            trust_remote_code=True,
            token=token,
        )

        logger.info("Model loaded successfully")

        return cls(model, tokenizer, system_prompt, device_map)

    @classmethod
    def from_local(
        cls,
        model_path: str | Path,
        adapter_path: Optional[str | Path] = None,
        load_in_4bit: bool = True,
        torch_dtype: str = "bfloat16",
        system_prompt: Optional[str] = None,
    ) -> "VoiceModel":
        """
        Load voice model from local path.

        Args:
            model_path: Path to model
            adapter_path: Optional path to adapter
            load_in_4bit: Use 4-bit quantization
            torch_dtype: Torch dtype
            system_prompt: Custom system prompt

        Returns:
            VoiceModel instance
        """
        return cls.from_hub(
            model_id=str(model_path),
            adapter_id=str(adapter_path) if adapter_path else None,
            load_in_4bit=load_in_4bit,
            torch_dtype=torch_dtype,
            system_prompt=system_prompt,
        )

    def generate(
        self,
        user_message: str,
        conversation_history: Optional[list[dict]] = None,
        max_new_tokens: int = 1024,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        do_sample: bool = True,
        repetition_penalty: float = 1.1,
    ) -> str:
        """
        Generate a response to the user message.
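
        Example (illustrative turns, showing the expected conversation_history format):
            >>> history = [
            ...     {"role": "user", "content": "What drives the company?"},
            ...     {"role": "assistant", "content": "Customer obsession."},
            ... ]
            >>> model.generate("How does that shape hiring?", conversation_history=history)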

        Args:
            user_message: User's input message
            conversation_history: Optional list of prior messages
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Top-p sampling
            top_k: Top-k sampling
            do_sample: Whether to sample
            repetition_penalty: Repetition penalty

        Returns:
            Generated response text
        """
        # Build messages
        messages = [{"role": "system", "content": self.system_prompt}]

        # Add conversation history
        if conversation_history:
            for msg in conversation_history:
                messages.append(msg)

        # Add current message
        messages.append({"role": "user", "content": user_message})

        # Format with chat template
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Tokenize
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048 - max_new_tokens,
        ).to(self.model.device)

        # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=do_sample,
                repetition_penalty=repetition_penalty,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Decode response only (skip input)
        response = self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        return response.strip()

    def generate_stream(
        self,
        user_message: str,
        conversation_history: Optional[list[dict]] = None,
        max_new_tokens: int = 1024,
        temperature: float = 0.7,
        top_p: float = 0.9,
        **kwargs,
    ) -> Iterator[str]:
        """
        Generate a streaming response.

        Args:
            user_message: User's input message
            conversation_history: Optional prior messages
            max_new_tokens: Maximum tokens
            temperature: Sampling temperature
            top_p: Top-p sampling
            **kwargs: Additional generation kwargs

        Yields:
            Token strings as they're generated
        """
        from threading import Thread

        # Build messages
        messages = [{"role": "system", "content": self.system_prompt}]
        if conversation_history:
            messages.extend(conversation_history)
        messages.append({"role": "user", "content": user_message})

        # Format prompt
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Tokenize
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048 - max_new_tokens,
        ).to(self.model.device)

        # Create streamer
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )

        # Generation kwargs
        generation_kwargs = dict(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            streamer=streamer,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            **kwargs,
        )

        # Run generation in thread
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield tokens
        for token in streamer:
            yield token

        thread.join()

    def update_system_prompt(self, new_prompt: str) -> None:
        """Update the system prompt."""
        self.system_prompt = new_prompt
        logger.info("System prompt updated")

    def get_system_prompt(self) -> str:
        """Get current system prompt."""
        return self.system_prompt


def main():
    """CLI entry point for testing the voice model."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Test the CEO voice model",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python voice_model.py --model username/ceo-voice-model --prompt "What is AI?"
  python voice_model.py --model ./local_model --prompt "Your vision?"
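  python voice_model.py --model username/ceo-voice-model --prompt "What is AI?" --stream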
""", ) parser.add_argument("--model", required=True, help="Model ID or path") parser.add_argument("--adapter", help="Adapter ID or path") parser.add_argument("--prompt", required=True, help="User prompt") parser.add_argument("--no-4bit", action="store_true", help="Disable 4-bit") parser.add_argument("--temperature", type=float, default=0.7) parser.add_argument("--max-tokens", type=int, default=512) parser.add_argument("--stream", action="store_true", help="Stream output") args = parser.parse_args() # Load model print(f"Loading model: {args.model}") model = VoiceModel.from_hub( model_id=args.model, adapter_id=args.adapter, load_in_4bit=not args.no_4bit, ) # Generate print(f"\nPrompt: {args.prompt}\n") print("-" * 50) if args.stream: for token in model.generate_stream( args.prompt, max_new_tokens=args.max_tokens, temperature=args.temperature, ): print(token, end="", flush=True) print() else: response = model.generate( args.prompt, max_new_tokens=args.max_tokens, temperature=args.temperature, ) print(response) if __name__ == "__main__": main()