# ai_exec/src/inference/dual_llm_pipeline.py
"""
Dual LLM Pipeline Module
Orchestrates the Voice Model and Refinement Model for complete
CEO-style response generation with quality polish.
Example usage:
pipeline = DualLLMPipeline.from_hub(
voice_model_id="username/ceo-voice-model",
refinement_model_id="meta-llama/Llama-3.3-8B-Instruct",
)
response = pipeline.generate("What is your vision for AI?")
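For streaming (voice model only, since refinement needs the complete draft):
for token in pipeline.generate_stream("What is your vision for AI?"):
print(token, end="", flush=True)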
"""
import hashlib
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Iterator, Optional
from loguru import logger
from .voice_model import VoiceModel
from .refinement_model import RefinementModel
@dataclass
class PipelineResponse:
"""Response from the dual LLM pipeline."""
final_response: str
draft_response: str
was_refined: bool
voice_model_time: float
refinement_time: float
total_time: float
metadata: dict = field(default_factory=dict)
def to_dict(self) -> dict:
"""Convert to dictionary."""
return {
"final_response": self.final_response,
"draft_response": self.draft_response,
"was_refined": self.was_refined,
"voice_model_time": self.voice_model_time,
"refinement_time": self.refinement_time,
"total_time": self.total_time,
"metadata": self.metadata,
}
class LRUCache:
"""Simple LRU cache for response caching."""
def __init__(self, max_size: int = 100):
self.cache: OrderedDict[str, str] = OrderedDict()
self.max_size = max_size
def get(self, key: str) -> Optional[str]:
if key in self.cache:
self.cache.move_to_end(key)
return self.cache[key]
return None
def set(self, key: str, value: str) -> None:
if key in self.cache:
self.cache.move_to_end(key)
else:
if len(self.cache) >= self.max_size:
self.cache.popitem(last=False)
self.cache[key] = value
def clear(self) -> None:
self.cache.clear()
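# Illustrative sketch of the eviction behavior (comment only, not executed):
#
#     cache = LRUCache(max_size=2)
#     cache.set("a", "1")
#     cache.set("b", "2")
#     cache.set("c", "3")  # evicts "a", the least recently used entry
#     assert cache.get("a") is None
#     assert cache.get("c") == "3"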
class DualLLMPipeline:
"""
Dual LLM Pipeline for CEO-style responses.
Orchestrates:
1. Voice Model: Generates authentic CEO-style draft
2. Refinement Model: Polishes for grammar, clarity, professionalism
Example:
>>> pipeline = DualLLMPipeline.from_hub("username/model")
>>> response = pipeline.generate("Your thoughts on AI?")
>>> print(response.final_response)
"""
def __init__(
self,
voice_model: VoiceModel,
refinement_model: Optional[RefinementModel] = None,
enable_cache: bool = True,
cache_size: int = 100,
min_length_for_refinement: int = 50,
):
"""
Initialize the pipeline.
Args:
voice_model: Voice model instance
refinement_model: Refinement model instance (optional)
enable_cache: Whether to cache responses
cache_size: Maximum cache entries
min_length_for_refinement: Minimum response length for refinement
"""
self.voice_model = voice_model
self.refinement_model = refinement_model
self.enable_cache = enable_cache
self.min_length_for_refinement = min_length_for_refinement
# Conversation history
self.conversation_history: list[dict] = []
self.max_history_turns: int = 5
# Response cache
self.cache = LRUCache(cache_size) if enable_cache else None
# Refinement control
self._refinement_enabled = refinement_model is not None
# Stats tracking
self._total_requests = 0
self._cache_hits = 0
@classmethod
def from_hub(
cls,
voice_model_id: str,
voice_adapter_id: Optional[str] = None,
refinement_model_id: Optional[str] = None,
load_in_4bit: bool = True,
enable_refinement: bool = True,
enable_cache: bool = True,
token: Optional[str] = None,
) -> "DualLLMPipeline":
"""
Load pipeline from Hugging Face Hub.
Args:
voice_model_id: Voice model or adapter ID
voice_adapter_id: Separate adapter ID (if applicable)
refinement_model_id: Refinement model ID (defaults to RefinementModel.DEFAULT_MODEL)
load_in_4bit: Use 4-bit quantization
enable_refinement: Whether to enable refinement stage
enable_cache: Whether to cache responses
token: HF token
Returns:
DualLLMPipeline instance
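Example (illustrative; both model IDs are placeholders):
pipeline = DualLLMPipeline.from_hub(
voice_model_id="username/ceo-voice-model",
refinement_model_id="meta-llama/Llama-3.1-8B-Instruct",
)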
"""
logger.info("Loading dual LLM pipeline...")
# Load voice model
logger.info(f"Loading voice model: {voice_model_id}")
voice_model = VoiceModel.from_hub(
model_id=voice_model_id,
adapter_id=voice_adapter_id,
load_in_4bit=load_in_4bit,
token=token,
)
# Load refinement model (optional)
refinement_model = None
if enable_refinement:
refinement_id = refinement_model_id or RefinementModel.DEFAULT_MODEL
logger.info(f"Loading refinement model: {refinement_id}")
refinement_model = RefinementModel.from_hub(
model_id=refinement_id,
load_in_4bit=load_in_4bit,
token=token,
)
logger.info("Pipeline loaded successfully")
return cls(voice_model, refinement_model, enable_cache)
def generate(
self,
user_message: str,
skip_refinement: bool = False,
use_cache: bool = True,
voice_temperature: float = 0.7,
refinement_temperature: float = 0.3,
max_new_tokens: int = 1024,
) -> PipelineResponse:
"""
Generate a complete response through the pipeline.
Args:
user_message: User's input message
skip_refinement: Skip refinement stage
use_cache: Use response cache
voice_temperature: Temperature for voice model
refinement_temperature: Temperature for refinement
max_new_tokens: Maximum tokens to generate
Returns:
PipelineResponse with final and draft responses
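Example (illustrative):
response = pipeline.generate("Your thoughts on AI?")
print(response.final_response)
print(f"refined={response.was_refined} in {response.total_time:.2f}s")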
"""
start_time = time.time()
self._total_requests += 1
# Check cache
cache_key = self._get_cache_key(user_message)
if use_cache and self.cache:
cached = self.cache.get(cache_key)
if cached:
self._cache_hits += 1
logger.debug(f"Cache hit for: {user_message[:50]}...")
return PipelineResponse(
final_response=cached,
draft_response=cached,
was_refined=False,
voice_model_time=0.0,
refinement_time=0.0,
total_time=0.0,
metadata={"cache_hit": True},
)
# Stage 1: Voice Model
voice_start = time.time()
draft_response = self.voice_model.generate(
user_message=user_message,
conversation_history=self.conversation_history,
max_new_tokens=max_new_tokens,
temperature=voice_temperature,
)
voice_time = time.time() - voice_start
logger.debug(f"Voice model generated in {voice_time:.2f}s")
# Stage 2: Refinement (optional)
final_response = draft_response
refinement_time = 0.0
was_refined = False
should_refine = (
self.refinement_model is not None
and self._refinement_enabled
and not skip_refinement
and len(draft_response) >= self.min_length_for_refinement
)
if should_refine:
refine_start = time.time()
final_response = self.refinement_model.refine(
draft_response=draft_response,
max_new_tokens=max_new_tokens,
temperature=refinement_temperature,
)
refinement_time = time.time() - refine_start
was_refined = True
logger.debug(f"Refinement completed in {refinement_time:.2f}s")
total_time = time.time() - start_time
# Update conversation history
self._update_history(user_message, final_response)
# Cache the response
if self.cache and use_cache:
self.cache.set(cache_key, final_response)
return PipelineResponse(
final_response=final_response,
draft_response=draft_response,
was_refined=was_refined,
voice_model_time=voice_time,
refinement_time=refinement_time,
total_time=total_time,
)
def generate_stream(
self,
user_message: str,
skip_refinement: bool = False,
voice_temperature: float = 0.7,
max_new_tokens: int = 1024,
) -> Iterator[str]:
"""
Generate a streaming response (voice model only, no refinement).
Note: Refinement is not supported in streaming mode because
the complete draft is needed before it can be refined. This
method also does not update conversation history.
Args:
user_message: User's input
skip_refinement: Accepted for API symmetry; ignored, since
refinement never runs in streaming mode
voice_temperature: Temperature
max_new_tokens: Maximum tokens
Yields:
Token strings as generated
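Example (illustrative):
for token in pipeline.generate_stream("Your thoughts on AI?"):
print(token, end="", flush=True)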
"""
# Stream from voice model
for token in self.voice_model.generate_stream(
user_message=user_message,
conversation_history=self.conversation_history,
max_new_tokens=max_new_tokens,
temperature=voice_temperature,
):
yield token
def generate_ab_test(
self,
user_message: str,
voice_temperature: float = 0.7,
max_new_tokens: int = 1024,
) -> dict:
"""
Generate both refined and unrefined responses for comparison.
Args:
user_message: User's input
voice_temperature: Temperature
max_new_tokens: Maximum tokens
Returns:
Dictionary with 'draft', 'refined', and timing info
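Example (illustrative):
result = pipeline.generate_ab_test("Your thoughts on AI?")
if result["different"]:
print(result["draft"], result["refined"], sep="\n---\n")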
"""
start_time = time.time()
# Generate draft
voice_start = time.time()
draft = self.voice_model.generate(
user_message=user_message,
conversation_history=self.conversation_history,
max_new_tokens=max_new_tokens,
temperature=voice_temperature,
)
voice_time = time.time() - voice_start
# Generate refined (if available)
refined = draft
refinement_time = 0.0
if self.refinement_model:
refine_start = time.time()
refined = self.refinement_model.refine(draft)
refinement_time = time.time() - refine_start
return {
"draft": draft,
"refined": refined,
"voice_time": voice_time,
"refinement_time": refinement_time,
"total_time": time.time() - start_time,
"different": draft != refined,
}
def _get_cache_key(self, user_message: str) -> str:
"""Generate a cache key from message only."""
# Keyed on message content alone for a better hit rate. Note that a
# cache hit skips generation entirely, so an identical message returns
# the same response even if the conversation context has changed.
return hashlib.md5(user_message.encode()).hexdigest()
def _update_history(self, user_message: str, response: str) -> None:
"""Update conversation history."""
self.conversation_history.append({
"role": "user",
"content": user_message,
})
self.conversation_history.append({
"role": "assistant",
"content": response,
})
# Trim history
max_messages = self.max_history_turns * 2
if len(self.conversation_history) > max_messages:
self.conversation_history = self.conversation_history[-max_messages:]
def clear_history(self) -> None:
"""Clear conversation history."""
self.conversation_history = []
logger.info("Conversation history cleared")
def clear_cache(self) -> None:
"""Clear response cache."""
if self.cache:
self.cache.clear()
logger.info("Response cache cleared")
def get_stats(self) -> dict:
"""Get pipeline statistics."""
return {
"total_requests": self._total_requests,
"cache_hits": self._cache_hits,
"cache_hit_rate": (
self._cache_hits / self._total_requests
if self._total_requests > 0
else 0
),
"history_length": len(self.conversation_history),
"refinement_available": self.refinement_model is not None,
"refinement_enabled": self._refinement_enabled,
}
def set_max_history_turns(self, turns: int) -> None:
"""Set maximum conversation history turns."""
self.max_history_turns = turns
def enable_refinement(self, enable: bool = True) -> None:
"""Enable or disable refinement stage."""
if enable and self.refinement_model is None:
logger.warning("No refinement model loaded; refinement stays disabled")
return
self._refinement_enabled = enable
def main():
"""CLI entry point for testing the pipeline."""
import argparse
parser = argparse.ArgumentParser(
description="Test the dual LLM pipeline",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python dual_llm_pipeline.py --voice-model username/model --prompt "Question?"
python dual_llm_pipeline.py --voice-model username/model --interactive
""",
)
parser.add_argument("--voice-model", required=True, help="Voice model ID")
parser.add_argument("--voice-adapter", help="Voice adapter ID")
parser.add_argument("--refinement-model", help="Refinement model ID")
parser.add_argument("--prompt", help="Single prompt to process")
parser.add_argument("--interactive", action="store_true", help="Interactive mode")
parser.add_argument("--no-refinement", action="store_true", help="Skip refinement")
parser.add_argument("--no-4bit", action="store_true", help="Disable 4-bit")
parser.add_argument("--ab-test", action="store_true", help="Show both versions")
args = parser.parse_args()
# Load pipeline
print("Loading pipeline...")
pipeline = DualLLMPipeline.from_hub(
voice_model_id=args.voice_model,
voice_adapter_id=args.voice_adapter,
refinement_model_id=args.refinement_model,
load_in_4bit=not args.no_4bit,
enable_refinement=not args.no_refinement,
)
if args.interactive:
# Interactive mode
print("\nAI Executive Chatbot")
print("Type 'quit' to exit, 'clear' to clear history\n")
while True:
try:
user_input = input("You: ").strip()
except (EOFError, KeyboardInterrupt):
break
if not user_input:
continue
if user_input.lower() == "quit":
break
if user_input.lower() == "clear":
pipeline.clear_history()
print("History cleared.\n")
continue
response = pipeline.generate(user_input, skip_refinement=args.no_refinement)
print(f"\nCEO: {response.final_response}")
print(f"[Time: {response.total_time:.2f}s, Refined: {response.was_refined}]\n")
elif args.prompt:
# Single prompt
if args.ab_test:
result = pipeline.generate_ab_test(args.prompt)
print("\n=== Draft Response ===")
print(result["draft"])
print(f"\n[Voice model time: {result['voice_time']:.2f}s]")
print("\n=== Refined Response ===")
print(result["refined"])
print(f"\n[Refinement time: {result['refinement_time']:.2f}s]")
print(f"[Total time: {result['total_time']:.2f}s]")
print(f"[Changed: {result['different']}]")
else:
response = pipeline.generate(args.prompt, skip_refinement=args.no_refinement)
print(f"\nResponse: {response.final_response}")
print(f"\n[Total time: {response.total_time:.2f}s]")
print(f"[Voice: {response.voice_model_time:.2f}s, Refine: {response.refinement_time:.2f}s]")
else:
print("Provide --prompt or --interactive")
return 1
return 0
if __name__ == "__main__":
raise SystemExit(main())