""" Refinement Model Module Load and run Llama 3.3 8B-Instruct for response refinement. Polishes CEO responses for grammar, clarity, and professional formatting. Example usage: model = RefinementModel.from_hub("meta-llama/Llama-3.3-8B-Instruct") refined = model.refine("Draft CEO response...") """ import os from pathlib import Path from typing import Iterator, Optional from loguru import logger try: import torch from transformers import ( AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer, ) INFERENCE_AVAILABLE = True except ImportError: INFERENCE_AVAILABLE = False logger.warning("Inference dependencies not available") from .prompt_templates import ( REFINEMENT_MODEL_SYSTEM_PROMPT, get_refinement_prompt, format_refinement_request, ) class RefinementModel: """ Refinement Model for polishing CEO responses. Takes draft responses from the Voice Model and improves them for grammar, clarity, and professional formatting while preserving voice. Example: >>> model = RefinementModel.from_hub() >>> refined = model.refine("Draft response text...") >>> print(refined) """ # Default model for refinement DEFAULT_MODEL = "meta-llama/Llama-3.3-8B-Instruct" def __init__( self, model, tokenizer, system_prompt: Optional[str] = None, device: str = "auto", ): """ Initialize with loaded model and tokenizer. Args: model: Loaded HuggingFace model tokenizer: Loaded tokenizer system_prompt: Custom system prompt device: Device for inference """ self.model = model self.tokenizer = tokenizer self.system_prompt = system_prompt or REFINEMENT_MODEL_SYSTEM_PROMPT self.device = device # Ensure padding token if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token @classmethod def from_hub( cls, model_id: Optional[str] = None, load_in_4bit: bool = True, load_in_8bit: bool = False, torch_dtype: str = "bfloat16", device_map: str = "auto", system_prompt: Optional[str] = None, token: Optional[str] = None, ) -> "RefinementModel": """ Load refinement model from Hugging Face Hub. Args: model_id: Model ID (defaults to Llama 3.3 8B) load_in_4bit: Use 4-bit quantization (recommended) load_in_8bit: Use 8-bit quantization torch_dtype: Torch dtype device_map: Device mapping system_prompt: Custom system prompt token: HF token Returns: RefinementModel instance """ if not INFERENCE_AVAILABLE: raise ImportError( "Inference dependencies not available. 
                "Inference dependencies not available. Install with:\n"
                "pip install torch transformers bitsandbytes"
            )

        model_id = model_id or cls.DEFAULT_MODEL
        token = token or os.environ.get("HF_TOKEN")

        # Get torch dtype
        dtype_map = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32,
        }
        dtype = dtype_map.get(torch_dtype, torch.bfloat16)

        # Quantization config
        quantization_config = None
        if load_in_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=dtype,
                bnb_4bit_use_double_quant=True,
            )
        elif load_in_8bit:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)

        logger.info(f"Loading refinement model: {model_id}")

        # Load model
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            quantization_config=quantization_config,
            device_map=device_map,
            torch_dtype=dtype,
            trust_remote_code=True,
            token=token,
        )

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True,
            token=token,
        )

        logger.info("Refinement model loaded successfully")

        return cls(model, tokenizer, system_prompt, device_map)

    def refine(
        self,
        draft_response: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.3,
        top_p: float = 0.9,
    ) -> str:
        """
        Refine a draft response.

        Args:
            draft_response: Draft CEO response to refine
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature (lower = more conservative)
            top_p: Top-p sampling

        Returns:
            Refined response text
        """
        # Build messages
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": format_refinement_request(draft_response)},
        ]

        # Format with chat template
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Tokenize
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=4096 - max_new_tokens,
        ).to(self.model.device)

        # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=temperature > 0,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # Decode response only
        refined = self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )

        return self._clean_response(refined)

    def refine_stream(
        self,
        draft_response: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.3,
        top_p: float = 0.9,
    ) -> Iterator[str]:
        """
        Refine with streaming output.

        Args:
            draft_response: Draft to refine
            max_new_tokens: Maximum tokens
            temperature: Sampling temperature
            top_p: Top-p sampling

        Yields:
            Token strings as generated
        """
        from threading import Thread

        # Build messages
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": format_refinement_request(draft_response)},
        ]

        # Format prompt
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Tokenize
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=4096 - max_new_tokens,
        ).to(self.model.device)

        # Create streamer
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )

        # Generation kwargs
        generation_kwargs = dict(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=temperature > 0,
            streamer=streamer,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        # Run in thread
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield tokens
        for token in streamer:
            yield token

        thread.join()

    def _clean_response(self, response: str) -> str:
        """Clean up the refined response."""
        response = response.strip()

        # Remove common unwanted prefixes
        unwanted_prefixes = [
            "Here is the refined response:",
            "Here's the refined version:",
            "Refined response:",
            "---",
        ]
        for prefix in unwanted_prefixes:
            if response.startswith(prefix):
                response = response[len(prefix):].strip()

        # Remove trailing artifacts
        if response.endswith("---"):
            response = response[:-3].strip()

        return response

    def should_refine(self, draft_response: str, min_length: int = 50) -> bool:
        """
        Determine if a response should be refined.

        Very short or simple responses might not need refinement.

        Args:
            draft_response: Draft to evaluate
            min_length: Minimum character length to warrant refinement

        Returns:
            Whether refinement is recommended
        """
        if len(draft_response) < min_length:
            return False

        # Check for obvious issues that need refinement
        obvious_issues = [
            "  ",  # Double spaces
            "\n\n\n",  # Excessive newlines
        ]
        for issue in obvious_issues:
            if issue in draft_response:
                return True

        # Default: refine if above minimum length
        return True

    def update_system_prompt(self, new_prompt: str) -> None:
        """Update the system prompt."""
        self.system_prompt = new_prompt
        logger.info("Refinement system prompt updated")


def main():
    """CLI entry point for testing the refinement model."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Test the refinement model",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python refinement_model.py --input "Draft text to refine..."
  python refinement_model.py --input-file draft.txt --output refined.txt
""",
    )
    parser.add_argument("--model", help="Model ID (default: Llama 3.1 8B Instruct)")
    parser.add_argument("--input", help="Text to refine")
    parser.add_argument("--input-file", help="File containing text to refine")
    parser.add_argument("--output", help="Output file for refined text")
    parser.add_argument("--no-4bit", action="store_true", help="Disable 4-bit")
    parser.add_argument("--temperature", type=float, default=0.3)
    parser.add_argument("--stream", action="store_true", help="Stream output")
    args = parser.parse_args()

    # Get input text
    if args.input:
        draft = args.input
    elif args.input_file:
        with open(args.input_file, "r") as f:
            draft = f.read()
    else:
        print("Error: Provide --input or --input-file")
        return 1

    # Load model
    print("Loading refinement model...")
    model = RefinementModel.from_hub(
        model_id=args.model,
        load_in_4bit=not args.no_4bit,
    )

    # Refine
    print("\nOriginal:")
    print("-" * 50)
    print(draft[:500] + "..." if len(draft) > 500 else draft)

    print("\nRefining...")
    print("-" * 50)

    if args.stream:
        refined_parts = []
        for token in model.refine_stream(draft, temperature=args.temperature):
            print(token, end="", flush=True)
            refined_parts.append(token)
        refined = "".join(refined_parts)
        print()
    else:
        refined = model.refine(draft, temperature=args.temperature)
        print(refined)

    # Save if output specified
    if args.output:
        with open(args.output, "w") as f:
            f.write(refined)
        print(f"\nSaved to: {args.output}")

    return 0


if __name__ == "__main__":
    exit(main())
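
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; nothing below runs as part of this module).
# `VoiceModel` and its `generate()` method are assumed names for the upstream
# draft generator mentioned in the class docstring -- adjust the import and
# call to match the actual Voice Model implementation.
#
#     from voice_model import VoiceModel          # hypothetical module
#     from refinement_model import RefinementModel
#
#     voice = VoiceModel.from_hub()
#     refiner = RefinementModel.from_hub()
#
#     draft = voice.generate("Reply to the investor update email...")
#     final = refiner.refine(draft) if refiner.should_refine(draft) else draft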