| """ | |
| Refinement Model Module | |
| Load and run Llama 3.3 8B-Instruct for response refinement. | |
| Polishes CEO responses for grammar, clarity, and professional formatting. | |
| Example usage: | |
| model = RefinementModel.from_hub("meta-llama/Llama-3.3-8B-Instruct") | |
| refined = model.refine("Draft CEO response...") | |
| """ | |

import os
from typing import Iterator, Optional

from loguru import logger

try:
    import torch
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        BitsAndBytesConfig,
        TextIteratorStreamer,
    )

    INFERENCE_AVAILABLE = True
except ImportError:
    INFERENCE_AVAILABLE = False
    logger.warning("Inference dependencies not available")

# Fall back to an absolute import so the module also works when run as a
# script (see main() below).
try:
    from .prompt_templates import (
        REFINEMENT_MODEL_SYSTEM_PROMPT,
        format_refinement_request,
    )
except ImportError:
    from prompt_templates import (
        REFINEMENT_MODEL_SYSTEM_PROMPT,
        format_refinement_request,
    )


class RefinementModel:
    """
    Refinement Model for polishing CEO responses.

    Takes draft responses from the Voice Model and improves them for
    grammar, clarity, and professional formatting while preserving voice.

    Example:
        >>> model = RefinementModel.from_hub()
        >>> refined = model.refine("Draft response text...")
        >>> print(refined)
    """

    # Default model for refinement
    DEFAULT_MODEL = "meta-llama/Llama-3.1-8B-Instruct"

    def __init__(
        self,
        model,
        tokenizer,
        system_prompt: Optional[str] = None,
        device: str = "auto",
    ):
        """
        Initialize with loaded model and tokenizer.

        Args:
            model: Loaded HuggingFace model
            tokenizer: Loaded tokenizer
            system_prompt: Custom system prompt
            device: Device for inference
        """
        self.model = model
        self.tokenizer = tokenizer
        self.system_prompt = system_prompt or REFINEMENT_MODEL_SYSTEM_PROMPT
        self.device = device

        # Llama tokenizers often ship without a pad token; reuse EOS
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    @classmethod
    def from_hub(
        cls,
        model_id: Optional[str] = None,
        load_in_4bit: bool = True,
        load_in_8bit: bool = False,
        torch_dtype: str = "bfloat16",
        device_map: str = "auto",
        system_prompt: Optional[str] = None,
        token: Optional[str] = None,
    ) -> "RefinementModel":
        """
        Load refinement model from Hugging Face Hub.

        Args:
            model_id: Model ID (defaults to Llama 3.1 8B)
            load_in_4bit: Use 4-bit quantization (recommended)
            load_in_8bit: Use 8-bit quantization
            torch_dtype: Torch dtype
            device_map: Device mapping
            system_prompt: Custom system prompt
            token: HF token

        Returns:
            RefinementModel instance
        """
        if not INFERENCE_AVAILABLE:
            raise ImportError(
                "Inference dependencies not available. Install with:\n"
                "pip install torch transformers bitsandbytes"
            )

        model_id = model_id or cls.DEFAULT_MODEL
        token = token or os.environ.get("HF_TOKEN")

        # Get torch dtype
        dtype_map = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32,
        }
        dtype = dtype_map.get(torch_dtype, torch.bfloat16)

        # Quantization config
        quantization_config = None
        if load_in_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=dtype,
                bnb_4bit_use_double_quant=True,
            )
        elif load_in_8bit:
            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
| logger.info(f"Loading refinement model: {model_id}") | |
| # Load model | |
| model = AutoModelForCausalLM.from_pretrained( | |
| model_id, | |
| quantization_config=quantization_config, | |
| device_map=device_map, | |
| torch_dtype=dtype, | |
| trust_remote_code=True, | |
| token=token, | |
| ) | |
| # Load tokenizer | |
| tokenizer = AutoTokenizer.from_pretrained( | |
| model_id, | |
| trust_remote_code=True, | |
| token=token, | |
| ) | |
| logger.info("Refinement model loaded successfully") | |
| return cls(model, tokenizer, system_prompt, device_map) | |

    def refine(
        self,
        draft_response: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.3,
        top_p: float = 0.9,
    ) -> str:
        """
        Refine a draft response.

        Args:
            draft_response: Draft CEO response to refine
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature (lower = more conservative)
            top_p: Top-p sampling

        Returns:
            Refined response text
        """
        # Build messages
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": format_refinement_request(draft_response)},
        ]

        # Format with chat template
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Tokenize
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=4096 - max_new_tokens,
        ).to(self.model.device)
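
        # The 4096-token cap is a conservative budget, not a model limit;
        # Llama 3.x accepts much longer contexts. Reserving max_new_tokens
        # up front guarantees the prompt plus completion fit in the budget.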

        # Generate
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=temperature > 0,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
            )

        # generate() returns prompt + completion; decode only the new tokens
        refined = self.tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
        )
        return self._clean_response(refined)

    def refine_stream(
        self,
        draft_response: str,
        max_new_tokens: int = 1024,
        temperature: float = 0.3,
        top_p: float = 0.9,
    ) -> Iterator[str]:
        """
        Refine with streaming output.

        Args:
            draft_response: Draft to refine
            max_new_tokens: Maximum tokens
            temperature: Sampling temperature
            top_p: Top-p sampling

        Yields:
            Token strings as generated
        """
        from threading import Thread

        # Build messages
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": format_refinement_request(draft_response)},
        ]

        # Format prompt
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Tokenize
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=4096 - max_new_tokens,
        ).to(self.model.device)

        # Create streamer
        streamer = TextIteratorStreamer(
            self.tokenizer,
            skip_prompt=True,
            skip_special_tokens=True,
        )
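
        # skip_prompt=True makes the streamer emit only newly generated
        # text, mirroring the slice-decoding in refine() above.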

        # Generation kwargs
        generation_kwargs = dict(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=temperature > 0,
            streamer=streamer,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        # generate() blocks, so run it on a worker thread and yield tokens
        # from the streamer on this one
        thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
        thread.start()

        # Yield tokens
        for token in streamer:
            yield token
        thread.join()
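
        # Caveat: an exception inside model.generate() on the worker thread
        # is not re-raised here; passing timeout= to TextIteratorStreamer is
        # one way to keep iteration from blocking indefinitely in that case.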

    def _clean_response(self, response: str) -> str:
        """Clean up the refined response."""
        response = response.strip()

        # Remove common unwanted prefixes
        unwanted_prefixes = [
            "Here is the refined response:",
            "Here's the refined version:",
            "Refined response:",
            "---",
        ]
        for prefix in unwanted_prefixes:
            if response.startswith(prefix):
                response = response[len(prefix):].strip()

        # Remove trailing artifacts
        if response.endswith("---"):
            response = response[:-3].strip()

        return response

    def should_refine(self, draft_response: str, min_length: int = 50) -> bool:
        """
        Determine if a response should be refined.

        Very short or simple responses might not need refinement.

        Args:
            draft_response: Draft to evaluate
            min_length: Minimum character length to warrant refinement

        Returns:
            Whether refinement is recommended
        """
        if len(draft_response) < min_length:
            return False

        # Check for obvious issues that need refinement
        obvious_issues = [
            "  ",  # Double spaces
            "\n\n\n",  # Excessive newlines
        ]
        for issue in obvious_issues:
            if issue in draft_response:
                return True

        # Default: refine if above minimum length
        return True

    def update_system_prompt(self, new_prompt: str) -> None:
        """Update the system prompt."""
        self.system_prompt = new_prompt
        logger.info("Refinement system prompt updated")


def main():
    """CLI entry point for testing the refinement model."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Test the refinement model",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python refinement_model.py --input "Draft text to refine..."
  python refinement_model.py --input-file draft.txt --output refined.txt
        """,
    )
    parser.add_argument("--model", help="Model ID (default: Llama 3.1 8B)")
    parser.add_argument("--input", help="Text to refine")
    parser.add_argument("--input-file", help="File containing text to refine")
    parser.add_argument("--output", help="Output file for refined text")
    parser.add_argument("--no-4bit", action="store_true", help="Disable 4-bit")
    parser.add_argument("--temperature", type=float, default=0.3)
    parser.add_argument("--stream", action="store_true", help="Stream output")
    args = parser.parse_args()

    # Get input text
    if args.input:
        draft = args.input
    elif args.input_file:
        with open(args.input_file, "r") as f:
            draft = f.read()
    else:
        print("Error: Provide --input or --input-file")
        return 1

    # Load model
    print("Loading refinement model...")
    model = RefinementModel.from_hub(
        model_id=args.model,
        load_in_4bit=not args.no_4bit,
    )

    # Refine
    print("\nOriginal:")
    print("-" * 50)
    print((draft[:500] + "...") if len(draft) > 500 else draft)
    print("\nRefining...")
    print("-" * 50)

    if args.stream:
        refined_parts = []
        for token in model.refine_stream(draft, temperature=args.temperature):
            print(token, end="", flush=True)
            refined_parts.append(token)
        refined = "".join(refined_parts)
        print()
    else:
        refined = model.refine(draft, temperature=args.temperature)
        print(refined)

    # Save if output specified
    if args.output:
        with open(args.output, "w") as f:
            f.write(refined)
        print(f"\nSaved to: {args.output}")

    return 0


if __name__ == "__main__":
    raise SystemExit(main())