# ai_exec/src/inference/refinement_model.py
"""
Refinement Model Module
Load and run Llama 3.1 8B-Instruct for response refinement.
Polishes CEO responses for grammar, clarity, and professional formatting.
Example usage:
model = RefinementModel.from_hub("meta-llama/Llama-3.3-8B-Instruct")
refined = model.refine("Draft CEO response...")
"""
import os
from pathlib import Path
from typing import Iterator, Optional
from loguru import logger
try:
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
TextIteratorStreamer,
)
INFERENCE_AVAILABLE = True
except ImportError:
INFERENCE_AVAILABLE = False
logger.warning("Inference dependencies not available")
from .prompt_templates import (
REFINEMENT_MODEL_SYSTEM_PROMPT,
get_refinement_prompt,
format_refinement_request,
)
class RefinementModel:
"""
Refinement Model for polishing CEO responses.
Takes draft responses from the Voice Model and improves them for
grammar, clarity, and professional formatting while preserving voice.
Example:
>>> model = RefinementModel.from_hub()
>>> refined = model.refine("Draft response text...")
>>> print(refined)
"""
# Default model for refinement
DEFAULT_MODEL = "meta-llama/Llama-3.3-8B-Instruct"
def __init__(
self,
model,
tokenizer,
system_prompt: Optional[str] = None,
device: str = "auto",
):
"""
Initialize with loaded model and tokenizer.
Args:
model: Loaded HuggingFace model
tokenizer: Loaded tokenizer
system_prompt: Custom system prompt
device: Device for inference
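        Example (illustrative sketch; any causal LM with a matching tokenizer
        should work, access to the gated Llama weights is assumed):
            >>> from transformers import AutoModelForCausalLM, AutoTokenizer
            >>> tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
            >>> lm = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
            >>> model = RefinementModel(lm, tok)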
"""
self.model = model
self.tokenizer = tokenizer
self.system_prompt = system_prompt or REFINEMENT_MODEL_SYSTEM_PROMPT
self.device = device
# Ensure padding token
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
@classmethod
def from_hub(
cls,
model_id: Optional[str] = None,
load_in_4bit: bool = True,
load_in_8bit: bool = False,
torch_dtype: str = "bfloat16",
device_map: str = "auto",
system_prompt: Optional[str] = None,
token: Optional[str] = None,
) -> "RefinementModel":
"""
Load refinement model from Hugging Face Hub.
Args:
            model_id: Model ID (defaults to Llama 3.1 8B-Instruct)
load_in_4bit: Use 4-bit quantization (recommended)
load_in_8bit: Use 8-bit quantization
torch_dtype: Torch dtype
device_map: Device mapping
system_prompt: Custom system prompt
token: HF token
Returns:
RefinementModel instance
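        Example (illustrative sketch; assumes the HF_TOKEN environment variable
        grants access to the gated Llama weights):
            >>> model = RefinementModel.from_hub(load_in_4bit=True)
            >>> model = RefinementModel.from_hub(
            ...     model_id="meta-llama/Llama-3.1-8B-Instruct",
            ...     load_in_4bit=False,
            ...     load_in_8bit=True,
            ... )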
"""
if not INFERENCE_AVAILABLE:
raise ImportError(
"Inference dependencies not available. Install with:\n"
"pip install torch transformers bitsandbytes"
)
model_id = model_id or cls.DEFAULT_MODEL
token = token or os.environ.get("HF_TOKEN")
# Get torch dtype
dtype_map = {
"float16": torch.float16,
"bfloat16": torch.bfloat16,
"float32": torch.float32,
}
dtype = dtype_map.get(torch_dtype, torch.bfloat16)
# Quantization config
quantization_config = None
if load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=dtype,
bnb_4bit_use_double_quant=True,
)
elif load_in_8bit:
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
logger.info(f"Loading refinement model: {model_id}")
# Load model
model = AutoModelForCausalLM.from_pretrained(
model_id,
quantization_config=quantization_config,
device_map=device_map,
torch_dtype=dtype,
trust_remote_code=True,
token=token,
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_id,
trust_remote_code=True,
token=token,
)
logger.info("Refinement model loaded successfully")
return cls(model, tokenizer, system_prompt, device_map)
def refine(
self,
draft_response: str,
max_new_tokens: int = 1024,
temperature: float = 0.3,
top_p: float = 0.9,
) -> str:
"""
Refine a draft response.
Args:
draft_response: Draft CEO response to refine
max_new_tokens: Maximum tokens to generate
temperature: Sampling temperature (lower = more conservative)
top_p: Top-p sampling
Returns:
Refined response text
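        Example (illustrative; the exact output depends on the loaded model):
            >>> refined = model.refine(
            ...     "thanks for reaching out  we will get back to u soon",
            ...     temperature=0.2,
            ... )
            >>> print(refined)  # polished, grammatically correct version of the draft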
"""
# Build messages
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": format_refinement_request(draft_response)},
]
# Format with chat template
prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
# Tokenize
inputs = self.tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=4096 - max_new_tokens,
).to(self.model.device)
# Generate
with torch.no_grad():
outputs = self.model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
do_sample=temperature > 0,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
)
# Decode response only
refined = self.tokenizer.decode(
outputs[0][inputs["input_ids"].shape[1]:],
skip_special_tokens=True,
)
return self._clean_response(refined)
def refine_stream(
self,
draft_response: str,
max_new_tokens: int = 1024,
temperature: float = 0.3,
top_p: float = 0.9,
) -> Iterator[str]:
"""
Refine with streaming output.
Args:
draft_response: Draft to refine
max_new_tokens: Maximum tokens
temperature: Sampling temperature
top_p: Top-p sampling
Yields:
Token strings as generated
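        Example (illustrative; tokens arrive incrementally as the model generates):
            >>> for token in model.refine_stream("Draft response text..."):
            ...     print(token, end="", flush=True)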
"""
from threading import Thread
# Build messages
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": format_refinement_request(draft_response)},
]
# Format prompt
prompt = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
# Tokenize
inputs = self.tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=4096 - max_new_tokens,
).to(self.model.device)
# Create streamer
streamer = TextIteratorStreamer(
self.tokenizer,
skip_prompt=True,
skip_special_tokens=True,
)
# Generation kwargs
generation_kwargs = dict(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_p=top_p,
do_sample=temperature > 0,
streamer=streamer,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
)
# Run in thread
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread.start()
# Yield tokens
for token in streamer:
yield token
thread.join()
def _clean_response(self, response: str) -> str:
"""Clean up the refined response."""
response = response.strip()
# Remove common unwanted prefixes
unwanted_prefixes = [
"Here is the refined response:",
"Here's the refined version:",
"Refined response:",
"---",
]
for prefix in unwanted_prefixes:
if response.startswith(prefix):
response = response[len(prefix):].strip()
# Remove trailing artifacts
if response.endswith("---"):
response = response[:-3].strip()
return response
def should_refine(self, draft_response: str, min_length: int = 50) -> bool:
"""
Determine if a response should be refined.
Very short or simple responses might not need refinement.
Args:
draft_response: Draft to evaluate
min_length: Minimum character length to warrant refinement
Returns:
Whether refinement is recommended
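        Example (illustrative, using the default min_length of 50 characters):
            >>> model.should_refine("Thanks!")
            False
            >>> model.should_refine("A much longer draft reply that covers several different points in detail.")
            True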
"""
if len(draft_response) < min_length:
return False
# Check for obvious issues that need refinement
obvious_issues = [
" ", # Double spaces
"\n\n\n", # Excessive newlines
]
for issue in obvious_issues:
if issue in draft_response:
return True
# Default: refine if above minimum length
return True
def update_system_prompt(self, new_prompt: str) -> None:
"""Update the system prompt."""
self.system_prompt = new_prompt
logger.info("Refinement system prompt updated")
def main():
"""CLI entry point for testing the refinement model."""
import argparse
parser = argparse.ArgumentParser(
description="Test the refinement model",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python refinement_model.py --input "Draft text to refine..."
python refinement_model.py --input-file draft.txt --output refined.txt
""",
)
parser.add_argument("--model", help="Model ID (default: Llama 3.3 8B)")
parser.add_argument("--input", help="Text to refine")
parser.add_argument("--input-file", help="File containing text to refine")
parser.add_argument("--output", help="Output file for refined text")
parser.add_argument("--no-4bit", action="store_true", help="Disable 4-bit")
parser.add_argument("--temperature", type=float, default=0.3)
parser.add_argument("--stream", action="store_true", help="Stream output")
args = parser.parse_args()
# Get input text
if args.input:
draft = args.input
elif args.input_file:
with open(args.input_file, "r") as f:
draft = f.read()
else:
print("Error: Provide --input or --input-file")
return 1
# Load model
print(f"Loading refinement model...")
model = RefinementModel.from_hub(
model_id=args.model,
load_in_4bit=not args.no_4bit,
)
# Refine
print("\nOriginal:")
print("-" * 50)
    print((draft[:500] + "...") if len(draft) > 500 else draft)
print("\nRefining...")
print("-" * 50)
if args.stream:
refined_parts = []
for token in model.refine_stream(draft, temperature=args.temperature):
print(token, end="", flush=True)
refined_parts.append(token)
refined = "".join(refined_parts)
print()
else:
refined = model.refine(draft, temperature=args.temperature)
print(refined)
# Save if output specified
if args.output:
with open(args.output, "w") as f:
f.write(refined)
print(f"\nSaved to: {args.output}")
return 0
if __name__ == "__main__":
    raise SystemExit(main())