| """ | |
| Dual LLM Pipeline Module | |
| Orchestrates the Voice Model and Refinement Model for complete | |
| CEO-style response generation with quality polish. | |
| Example usage: | |
| pipeline = DualLLMPipeline.from_hub( | |
| voice_model_id="username/ceo-voice-model", | |
| refinement_model_id="meta-llama/Llama-3.3-8B-Instruct", | |
| ) | |
| response = pipeline.generate("What is your vision for AI?") | |
| """ | |
import hashlib
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Iterator, Optional

from loguru import logger

from .voice_model import VoiceModel
from .refinement_model import RefinementModel

@dataclass
class PipelineResponse:
    """Response from the dual LLM pipeline."""

    final_response: str
    draft_response: str
    was_refined: bool
    voice_model_time: float
    refinement_time: float
    total_time: float
    metadata: dict = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return {
            "final_response": self.final_response,
            "draft_response": self.draft_response,
            "was_refined": self.was_refined,
            "voice_model_time": self.voice_model_time,
            "refinement_time": self.refinement_time,
            "total_time": self.total_time,
            "metadata": self.metadata,
        }

class LRUCache:
    """Simple LRU cache for response caching.
    """

    def __init__(self, max_size: int = 100):
        self.cache = OrderedDict()
        self.max_size = max_size

    def get(self, key: str) -> Optional[str]:
        if key in self.cache:
            self.cache.move_to_end(key)
            return self.cache[key]
        return None

    def set(self, key: str, value: str) -> None:
        if key in self.cache:
            self.cache.move_to_end(key)
        else:
            if len(self.cache) >= self.max_size:
                self.cache.popitem(last=False)
        self.cache[key] = value

    def clear(self) -> None:
        self.cache.clear()

class DualLLMPipeline:
    """
    Dual LLM Pipeline for CEO-style responses.

    Orchestrates:
        1. Voice Model: Generates authentic CEO-style draft
        2. Refinement Model: Polishes for grammar, clarity, professionalism

    Example:
        >>> pipeline = DualLLMPipeline.from_hub("username/model")
        >>> response = pipeline.generate("Your thoughts on AI?")
        >>> print(response.final_response)
    """

    def __init__(
        self,
        voice_model: VoiceModel,
        refinement_model: Optional[RefinementModel] = None,
        enable_cache: bool = True,
        cache_size: int = 100,
        min_length_for_refinement: int = 50,
    ):
        """
        Initialize the pipeline.

        Args:
            voice_model: Voice model instance
            refinement_model: Refinement model instance (optional)
            enable_cache: Whether to cache responses
            cache_size: Maximum cache entries
            min_length_for_refinement: Minimum response length for refinement
        """
        self.voice_model = voice_model
        self.refinement_model = refinement_model
        self.enable_cache = enable_cache
        self.min_length_for_refinement = min_length_for_refinement

        # Conversation history
        self.conversation_history: list[dict] = []
        self.max_history_turns: int = 5

        # Response cache
        self.cache = LRUCache(cache_size) if enable_cache else None

        # Refinement control
        self._refinement_enabled = refinement_model is not None

        # Stats tracking
        self._total_requests = 0
        self._cache_hits = 0

    @classmethod
    def from_hub(
        cls,
        voice_model_id: str,
        voice_adapter_id: Optional[str] = None,
        refinement_model_id: Optional[str] = None,
        load_in_4bit: bool = True,
        enable_refinement: bool = True,
        enable_cache: bool = True,
        token: Optional[str] = None,
    ) -> "DualLLMPipeline":
        """
        Load pipeline from Hugging Face Hub.

        Args:
            voice_model_id: Voice model or adapter ID
            voice_adapter_id: Separate adapter ID (if applicable)
            refinement_model_id: Refinement model ID (default: Llama 3.3 8B)
            load_in_4bit: Use 4-bit quantization
            enable_refinement: Whether to enable refinement stage
            enable_cache: Whether to cache responses
            token: HF token

        Returns:
            DualLLMPipeline instance
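
        Example (placeholder repo IDs, as in the module docstring):
            >>> pipeline = DualLLMPipeline.from_hub(
            ...     voice_model_id="username/ceo-voice-model",
            ... )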
| """ | |
| logger.info("Loading dual LLM pipeline...") | |
| # Load voice model | |
| logger.info(f"Loading voice model: {voice_model_id}") | |
| voice_model = VoiceModel.from_hub( | |
| model_id=voice_model_id, | |
| adapter_id=voice_adapter_id, | |
| load_in_4bit=load_in_4bit, | |
| token=token, | |
| ) | |
| # Load refinement model (optional) | |
| refinement_model = None | |
| if enable_refinement: | |
| refinement_id = refinement_model_id or RefinementModel.DEFAULT_MODEL | |
| logger.info(f"Loading refinement model: {refinement_id}") | |
| refinement_model = RefinementModel.from_hub( | |
| model_id=refinement_id, | |
| load_in_4bit=load_in_4bit, | |
| token=token, | |
| ) | |
| logger.info("Pipeline loaded successfully") | |
| return cls(voice_model, refinement_model, enable_cache) | |
    def generate(
        self,
        user_message: str,
        skip_refinement: bool = False,
        use_cache: bool = True,
        voice_temperature: float = 0.7,
        refinement_temperature: float = 0.3,
        max_new_tokens: int = 1024,
    ) -> PipelineResponse:
        """
        Generate a complete response through the pipeline.

        Args:
            user_message: User's input message
            skip_refinement: Skip refinement stage
            use_cache: Use response cache
            voice_temperature: Temperature for voice model
            refinement_temperature: Temperature for refinement
            max_new_tokens: Maximum tokens to generate

        Returns:
            PipelineResponse with final and draft responses
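
        Example (assumes a loaded pipeline):
            >>> resp = pipeline.generate("What is your vision for AI?")
            >>> resp.was_refined  # True only when the refinement stage ran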
| """ | |
| start_time = time.time() | |
| self._total_requests += 1 | |
| # Check cache | |
| cache_key = self._get_cache_key(user_message) | |
| if use_cache and self.cache: | |
| cached = self.cache.get(cache_key) | |
| if cached: | |
| self._cache_hits += 1 | |
| logger.debug(f"Cache hit for: {user_message[:50]}...") | |
| return PipelineResponse( | |
| final_response=cached, | |
| draft_response=cached, | |
| was_refined=False, | |
| voice_model_time=0, | |
| refinement_time=0, | |
| total_time=0, | |
| metadata={"cache_hit": True}, | |
| ) | |
| # Stage 1: Voice Model | |
| voice_start = time.time() | |
| draft_response = self.voice_model.generate( | |
| user_message=user_message, | |
| conversation_history=self.conversation_history, | |
| max_new_tokens=max_new_tokens, | |
| temperature=voice_temperature, | |
| ) | |
| voice_time = time.time() - voice_start | |
| logger.debug(f"Voice model generated in {voice_time:.2f}s") | |
| # Stage 2: Refinement (optional) | |
| final_response = draft_response | |
| refinement_time = 0 | |
| was_refined = False | |
| should_refine = ( | |
| self.refinement_model is not None | |
| and self._refinement_enabled | |
| and not skip_refinement | |
| and len(draft_response) >= self.min_length_for_refinement | |
| ) | |
| if should_refine: | |
| refine_start = time.time() | |
| final_response = self.refinement_model.refine( | |
| draft_response=draft_response, | |
| max_new_tokens=max_new_tokens, | |
| temperature=refinement_temperature, | |
| ) | |
| refinement_time = time.time() - refine_start | |
| was_refined = True | |
| logger.debug(f"Refinement completed in {refinement_time:.2f}s") | |
| total_time = time.time() - start_time | |
| # Update conversation history | |
| self._update_history(user_message, final_response) | |
| # Cache the response | |
| if self.cache and use_cache: | |
| self.cache.set(cache_key, final_response) | |
| return PipelineResponse( | |
| final_response=final_response, | |
| draft_response=draft_response, | |
| was_refined=was_refined, | |
| voice_model_time=voice_time, | |
| refinement_time=refinement_time, | |
| total_time=total_time, | |
| ) | |
    def generate_stream(
        self,
        user_message: str,
        skip_refinement: bool = False,
        voice_temperature: float = 0.7,
        max_new_tokens: int = 1024,
    ) -> Iterator[str]:
        """
        Generate a streaming response (voice model only, no refinement).

        Note: Refinement is not supported in streaming mode because
        we need the complete draft to refine it.

        Args:
            user_message: User's input
            skip_refinement: Ignored; refinement never runs in streaming mode
            voice_temperature: Temperature
            max_new_tokens: Maximum tokens

        Yields:
            Token strings as generated
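
        Example (assumes a loaded pipeline):
            >>> for token in pipeline.generate_stream("Hello"):
            ...     print(token, end="", flush=True)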
| """ | |
| # Stream from voice model | |
| for token in self.voice_model.generate_stream( | |
| user_message=user_message, | |
| conversation_history=self.conversation_history, | |
| max_new_tokens=max_new_tokens, | |
| temperature=voice_temperature, | |
| ): | |
| yield token | |
    def generate_ab_test(
        self,
        user_message: str,
        voice_temperature: float = 0.7,
        max_new_tokens: int = 1024,
    ) -> dict:
        """
        Generate both refined and unrefined responses for comparison.

        Args:
            user_message: User's input
            voice_temperature: Temperature
            max_new_tokens: Maximum tokens

        Returns:
            Dictionary with 'draft', 'refined', and timing info
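
        Example (assumes a loaded pipeline):
            >>> result = pipeline.generate_ab_test("Your thoughts on AI?")
            >>> result["different"]  # True if refinement changed the draft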
| """ | |
| start_time = time.time() | |
| # Generate draft | |
| voice_start = time.time() | |
| draft = self.voice_model.generate( | |
| user_message=user_message, | |
| conversation_history=self.conversation_history, | |
| max_new_tokens=max_new_tokens, | |
| temperature=voice_temperature, | |
| ) | |
| voice_time = time.time() - voice_start | |
| # Generate refined (if available) | |
| refined = draft | |
| refinement_time = 0 | |
| if self.refinement_model: | |
| refine_start = time.time() | |
| refined = self.refinement_model.refine(draft) | |
| refinement_time = time.time() - refine_start | |
| return { | |
| "draft": draft, | |
| "refined": refined, | |
| "voice_time": voice_time, | |
| "refinement_time": refinement_time, | |
| "total_time": time.time() - start_time, | |
| "different": draft != refined, | |
| } | |
    def _get_cache_key(self, user_message: str) -> str:
        """Generate a cache key from the message only."""
        # Cache based on message content alone for a better hit rate;
        # context-dependent responses are handled by the model at generation time.
        return hashlib.md5(user_message.encode()).hexdigest()

    def _update_history(self, user_message: str, response: str) -> None:
        """Update conversation history."""
        self.conversation_history.append({
            "role": "user",
            "content": user_message,
        })
        self.conversation_history.append({
            "role": "assistant",
            "content": response,
        })
        # Trim history to the most recent turns (two messages per turn)
        max_messages = self.max_history_turns * 2
        if len(self.conversation_history) > max_messages:
            self.conversation_history = self.conversation_history[-max_messages:]

    def clear_history(self) -> None:
        """Clear conversation history."""
        self.conversation_history = []
        logger.info("Conversation history cleared")

    def clear_cache(self) -> None:
        """Clear response cache."""
        if self.cache:
            self.cache.clear()
            logger.info("Response cache cleared")

    def get_stats(self) -> dict:
        """Get pipeline statistics.
        """
        return {
            "total_requests": self._total_requests,
            "cache_hits": self._cache_hits,
            "cache_hit_rate": (
                self._cache_hits / self._total_requests
                if self._total_requests > 0
                else 0
            ),
            "history_length": len(self.conversation_history),
            "refinement_available": self.refinement_model is not None,
            "refinement_enabled": self._refinement_enabled,
        }

    def set_max_history_turns(self, turns: int) -> None:
        """Set maximum conversation history turns."""
        self.max_history_turns = turns

    def enable_refinement(self, enable: bool = True) -> None:
        """Enable or disable refinement stage."""
        if enable and self.refinement_model is None:
            logger.warning("No refinement model loaded")
        self._refinement_enabled = enable

def main():
    """CLI entry point for testing the pipeline."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Test the dual LLM pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python dual_llm_pipeline.py --voice-model username/model --prompt "Question?"
    python dual_llm_pipeline.py --voice-model username/model --interactive
""",
    )
    parser.add_argument("--voice-model", required=True, help="Voice model ID")
    parser.add_argument("--voice-adapter", help="Voice adapter ID")
    parser.add_argument("--refinement-model", help="Refinement model ID")
    parser.add_argument("--prompt", help="Single prompt to process")
    parser.add_argument("--interactive", action="store_true", help="Interactive mode")
    parser.add_argument("--no-refinement", action="store_true", help="Skip refinement")
    parser.add_argument("--no-4bit", action="store_true", help="Disable 4-bit quantization")
    parser.add_argument("--ab-test", action="store_true", help="Show both versions")
    args = parser.parse_args()

    # Load pipeline
    print("Loading pipeline...")
    pipeline = DualLLMPipeline.from_hub(
        voice_model_id=args.voice_model,
        voice_adapter_id=args.voice_adapter,
        refinement_model_id=args.refinement_model,
        load_in_4bit=not args.no_4bit,
        enable_refinement=not args.no_refinement,
    )

    if args.interactive:
        # Interactive mode
        print("\nAI Executive Chatbot")
        print("Type 'quit' to exit, 'clear' to clear history\n")
        while True:
            try:
                user_input = input("You: ").strip()
            except (EOFError, KeyboardInterrupt):
                break
            if not user_input:
                continue
            if user_input.lower() == "quit":
                break
            if user_input.lower() == "clear":
                pipeline.clear_history()
                print("History cleared.\n")
                continue
            response = pipeline.generate(user_input, skip_refinement=args.no_refinement)
            print(f"\nCEO: {response.final_response}")
            print(f"[Time: {response.total_time:.2f}s, Refined: {response.was_refined}]\n")
    elif args.prompt:
        # Single prompt
        if args.ab_test:
            result = pipeline.generate_ab_test(args.prompt)
            print("\n=== Draft Response ===")
            print(result["draft"])
            print(f"\n[Voice model time: {result['voice_time']:.2f}s]")
            print("\n=== Refined Response ===")
            print(result["refined"])
            print(f"\n[Refinement time: {result['refinement_time']:.2f}s]")
            print(f"[Total time: {result['total_time']:.2f}s]")
            print(f"[Changed: {result['different']}]")
        else:
            response = pipeline.generate(args.prompt, skip_refinement=args.no_refinement)
            print(f"\nResponse: {response.final_response}")
            print(f"\n[Total time: {response.total_time:.2f}s]")
            print(f"[Voice: {response.voice_model_time:.2f}s, Refine: {response.refinement_time:.2f}s]")
    else:
        print("Provide --prompt or --interactive")
        return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(main())