"""
Advanced Multi-Model Orchestrator with Parent LLM Reasoning

This version uses a parent LLM to intelligently analyze user requests and route
them to the most appropriate child model based on reasoning rather than simple
heuristics.
"""
|
|
from __future__ import annotations

import asyncio
import json
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoProcessor,
    pipeline,
)
|
|
| |
# Module-level logging setup: emit INFO and above; tag records with this module's name.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
class TaskType(Enum):
    """Task types that the parent LLM can route to.

    The string values double as the routing keywords the parent LLM is asked
    to emit (see ParentLLMRouter.routing_prompt_template).
    """
    TEXT = "TEXT"            # text processing: summarization, Q&A, generation
    CAPTION = "CAPTION"      # image -> text description
    TEXT2IMG = "TEXT2IMG"    # text -> generated image
    MULTIMODAL = "MULTIMODAL"  # composite tasks needing several child models
|
|
@dataclass
class ModelConfig:
    """Configuration for child models.

    Attributes:
        name: Hugging Face model identifier to load.
        model_type: TaskType this model serves.
        device: Target device string; resolved per instance at construction.
        max_length: Maximum generation length in tokens.
        temperature: Sampling temperature passed to generation.
    """
    name: str
    model_type: TaskType
    # BUG FIX: the previous plain default ("cuda" if torch.cuda.is_available()
    # else "cpu") was evaluated exactly once, at class-definition (import) time.
    # A default_factory re-evaluates CUDA availability for every instance.
    device: str = field(default_factory=lambda: "cuda" if torch.cuda.is_available() else "cpu")
    max_length: int = 512
    temperature: float = 0.7
|
|
@dataclass
class TaskResult:
    """Result from a task execution.

    Produced by AdvancedChildModel.process for both success and failure;
    on failure `output` is None and `error` carries the exception text.
    """
    task_type: TaskType        # which kind of task produced this result
    input_data: str            # the raw user input (text prompt or image path)
    output: Any                # model output (caption text, file path, ...) or None on error
    processing_time: float     # wall-clock seconds spent in process()
    confidence: float          # 0.0-1.0; routing/processing confidence
    reasoning: str             # human-readable explanation of how the result was produced
    timestamp: float           # time.time() when the result was created
    error: Optional[str] = None  # exception text if processing failed, else None
|
|
class ParentLLMRouter:
    """
    Parent LLM that uses reasoning to route tasks to appropriate child models.

    The router prompts a small causal LM with the user request, maps the
    model's continuation onto a TaskType via keyword scoring, and falls back
    to simple heuristics if generation fails.
    """

    def __init__(self, model_name: str = "distilgpt2", device: str = None):
        """Load the routing LLM onto the requested (or auto-detected) device.

        Args:
            model_name: Hugging Face identifier of the causal LM used for routing.
            device: Explicit device string; defaults to CUDA when available.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name

        logger.info(f"Loading parent LLM: {model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)

        # GPT-2-style tokenizers ship without a pad token; reuse EOS so padding
        # during generation is well defined.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.routing_prompt_template = """You are a router. Analyze this user request and choose the best model:
- "TEXT" for text summarization, Q&A, or text processing
- "CAPTION" for describing images
- "TEXT2IMG" for generating images from text
- "MULTIMODAL" for complex tasks requiring multiple models

Respond only with one keyword: TEXT, CAPTION, TEXT2IMG, or MULTIMODAL.

User request: {user_request}
Response:"""

    def analyze_request(self, user_request: str) -> Dict[str, Any]:
        """
        Use the parent LLM to analyze the request and determine the best routing.

        Returns:
            Dict with keys "task_type" (TaskType), "confidence" (float),
            "reasoning" (str) and "raw_response" (str). Never raises: on any
            failure the heuristic fallback is used instead.
        """
        try:
            prompt = self.routing_prompt_template.format(user_request=user_request)

            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=10,
                    temperature=0.1,
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )

            # BUG FIX: decode only the newly generated tokens. Causal-LM
            # generate() returns prompt + continuation, and the prompt itself
            # contains every routing keyword (TEXT, CAPTION, TEXT2IMG,
            # MULTIMODAL), which systematically biased the keyword scoring
            # in _extract_routing_decision.
            prompt_len = inputs["input_ids"].shape[1]
            response = self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

            routing_decision = self._extract_routing_decision(response, user_request)

            return {
                "task_type": routing_decision["task_type"],
                "confidence": routing_decision["confidence"],
                "reasoning": routing_decision["reasoning"],
                "raw_response": response
            }

        except Exception as e:
            logger.error(f"Error in parent LLM routing: {e}")
            # Generation failed (OOM, tokenizer error, ...); degrade gracefully.
            return self._fallback_routing(user_request)

    def _extract_routing_decision(self, response: str, user_request: str) -> Dict[str, Any]:
        """
        Extract the routing decision from the LLM response.

        Scores each TaskType by counting its associated keywords in the
        (upper-cased) response; confidence is the winner's share of the total.
        Ties resolve in dict-insertion order (TEXT first).
        """
        response_upper = response.upper()

        text_keywords = ["TEXT", "SUMMARIZE", "QUESTION", "ANSWER", "PROCESS"]
        caption_keywords = ["CAPTION", "DESCRIBE", "IMAGE", "PICTURE", "PHOTO"]
        text2img_keywords = ["TEXT2IMG", "GENERATE", "CREATE", "DRAW", "PAINT"]
        multimodal_keywords = ["MULTIMODAL", "BOTH", "COMBINE", "TOGETHER"]

        text_score = sum(1 for keyword in text_keywords if keyword in response_upper)
        caption_score = sum(1 for keyword in caption_keywords if keyword in response_upper)
        text2img_score = sum(1 for keyword in text2img_keywords if keyword in response_upper)
        multimodal_score = sum(1 for keyword in multimodal_keywords if keyword in response_upper)

        scores = {
            TaskType.TEXT: text_score,
            TaskType.CAPTION: caption_score,
            TaskType.TEXT2IMG: text2img_score,
            TaskType.MULTIMODAL: multimodal_score
        }

        task_type = max(scores, key=scores.get)
        max_score = scores[task_type]

        # When no keyword matched at all, fall back to a uniform 1-in-4 prior.
        total_score = sum(scores.values())
        confidence = max_score / total_score if total_score > 0 else 0.25

        reasoning = f"Parent LLM analyzed request and determined {task_type.value} task with {confidence:.2f} confidence"

        return {
            "task_type": task_type,
            "confidence": confidence,
            "reasoning": reasoning
        }

    def _fallback_routing(self, user_request: str) -> Dict[str, Any]:
        """
        Fallback routing using simple heuristics when LLM fails.

        Keyword checks run in priority order (caption, then text2img, then
        text), defaulting to TEXT; confidence is a fixed 0.5.
        """
        user_request_lower = user_request.lower()

        if any(word in user_request_lower for word in ["image", "picture", "photo", "describe", "caption"]):
            task_type = TaskType.CAPTION
            reasoning = "Fallback: Detected image-related keywords"
        elif any(word in user_request_lower for word in ["generate", "create", "draw", "paint", "image from"]):
            task_type = TaskType.TEXT2IMG
            reasoning = "Fallback: Detected image generation keywords"
        elif any(word in user_request_lower for word in ["summarize", "question", "answer", "text"]):
            task_type = TaskType.TEXT
            reasoning = "Fallback: Detected text processing keywords"
        else:
            task_type = TaskType.TEXT
            reasoning = "Fallback: Default to text processing"

        return {
            "task_type": task_type,
            "confidence": 0.5,
            "reasoning": reasoning,
            "raw_response": "Fallback routing used"
        }
|
|
class AdvancedChildModel:
    """Base class for child models with advanced capabilities.

    Lazily loads one model per TaskType (captioning, text-to-image, or text
    generation) and wraps every invocation in a TaskResult so that processing
    errors never escape to the caller.
    """

    def __init__(self, config: ModelConfig):
        """Store the configuration; the heavy model load is deferred to load_model()."""
        self.config = config
        self.model = None
        self.processor = None
        # BUG FIX: initialize tokenizer here so the attribute always exists.
        # Previously it was created only inside load_model() for TEXT models,
        # so touching it earlier (or on other model types) raised AttributeError.
        self.tokenizer = None
        self.is_loaded = False

    async def load_model(self):
        """Load the model for this config's TaskType (idempotent; raises on failure)."""
        if self.is_loaded:
            return

        try:
            logger.info(f"Loading {self.config.model_type.value} model: {self.config.name}")

            if self.config.model_type == TaskType.CAPTION:
                self.processor = AutoProcessor.from_pretrained(self.config.name)
                self.model = AutoModelForCausalLM.from_pretrained(self.config.name).to(self.config.device)
            elif self.config.model_type == TaskType.TEXT2IMG:
                self.model = StableDiffusionPipeline.from_pretrained(self.config.name).to(self.config.device)
            elif self.config.model_type == TaskType.TEXT:
                self.tokenizer = AutoTokenizer.from_pretrained(self.config.name)
                self.model = AutoModelForCausalLM.from_pretrained(self.config.name).to(self.config.device)

            self.is_loaded = True
            logger.info(f"Successfully loaded {self.config.model_type.value} model")

        except Exception as e:
            logger.error(f"Error loading {self.config.model_type.value} model: {e}")
            raise

    async def process(self, input_data: str, **kwargs) -> TaskResult:
        """Process the input and return a TaskResult.

        Loads the model on first use, dispatches on the configured TaskType,
        and captures any exception into the result instead of raising.
        """
        start_time = time.time()

        try:
            if not self.is_loaded:
                await self.load_model()

            if self.config.model_type == TaskType.CAPTION:
                output = await self._process_caption(input_data)
            elif self.config.model_type == TaskType.TEXT2IMG:
                output = await self._process_text2img(input_data)
            elif self.config.model_type == TaskType.TEXT:
                output = await self._process_text(input_data)
            else:
                raise ValueError(f"Unknown model type: {self.config.model_type}")

            processing_time = time.time() - start_time

            return TaskResult(
                task_type=self.config.model_type,
                input_data=input_data,
                output=output,
                processing_time=processing_time,
                confidence=0.9,
                reasoning=f"Successfully processed with {self.config.model_type.value} model",
                timestamp=time.time()
            )

        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"Error processing with {self.config.model_type.value} model: {e}")

            return TaskResult(
                task_type=self.config.model_type,
                input_data=input_data,
                output=None,
                processing_time=processing_time,
                confidence=0.0,
                reasoning=f"Error: {str(e)}",
                timestamp=time.time(),
                error=str(e)
            )

    async def _process_caption(self, image_path: str) -> str:
        """Caption the image at *image_path* and return the decoded caption text."""
        image = Image.open(image_path).convert("RGB")
        inputs = self.processor(images=image, return_tensors="pt").to(self.config.device)

        with torch.no_grad():
            # NOTE(review): temperature is passed without do_sample=True, so
            # decoding stays greedy and temperature has no effect — confirm intent.
            outputs = self.model.generate(
                **inputs,
                max_length=self.config.max_length,
                temperature=self.config.temperature
            )

        caption = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
        return caption

    async def _process_text2img(self, text_prompt: str) -> str:
        """Generate an image from *text_prompt*; returns the saved PNG's path."""
        image = self.model(text_prompt).images[0]

        # Timestamped file name keeps repeated generations from clobbering each other.
        output_path = f"generated_image_{int(time.time())}.png"
        image.save(output_path)
        return output_path

    async def _process_text(self, text_input: str) -> str:
        """Run causal-LM generation on *text_input* and return the decoded text."""
        inputs = self.tokenizer(text_input, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(self.config.device) for k, v in inputs.items()}

        with torch.no_grad():
            # NOTE(review): as in _process_caption, temperature without
            # do_sample=True is ignored by generate().
            outputs = self.model.generate(
                **inputs,
                max_length=self.config.max_length,
                temperature=self.config.temperature
            )

        result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return result
|
|
class AdvancedMultiModelOrchestrator:
    """
    Advanced multi-model orchestrator with parent LLM reasoning.

    Routes each request through ParentLLMRouter, dispatches to the matching
    AdvancedChildModel, and records every TaskResult in task_history.
    """

    def __init__(self, parent_model_name: str = "distilgpt2"):
        """Create the router and register one child model per supported TaskType."""
        self.parent_router = ParentLLMRouter(parent_model_name)
        self.child_models: Dict[TaskType, AdvancedChildModel] = {}
        self.task_history: List[TaskResult] = []

        self._initialize_child_models()

    def _initialize_child_models(self):
        """Initialize child models with configurations (models load lazily on first use)."""
        model_configs = {
            TaskType.CAPTION: ModelConfig(
                name="kunaliitkgp09/clip-gpt2-image-captioner",
                model_type=TaskType.CAPTION
            ),
            TaskType.TEXT2IMG: ModelConfig(
                name="kunaliitkgp09/flickr30k-text-to-image",
                model_type=TaskType.TEXT2IMG
            ),
            TaskType.TEXT: ModelConfig(
                name="distilgpt2",
                model_type=TaskType.TEXT
            )
        }

        for task_type, config in model_configs.items():
            self.child_models[task_type] = AdvancedChildModel(config)

    async def process_request(self, user_request: str) -> TaskResult:
        """
        Process a user request using parent LLM reasoning.

        Routes the request, runs the selected child model, merges the router's
        confidence/reasoning into the result, and appends it to task_history.
        """
        logger.info(f"Processing request: {user_request}")

        routing_decision = self.parent_router.analyze_request(user_request)
        task_type = routing_decision["task_type"]
        confidence = routing_decision["confidence"]
        reasoning = routing_decision["reasoning"]

        logger.info(f"Parent LLM routing decision: {task_type.value} (confidence: {confidence:.2f})")
        logger.info(f"Reasoning: {reasoning}")

        if task_type in self.child_models:
            child_model = self.child_models[task_type]
            result = await child_model.process(user_request)

            # BUG FIX: only adopt the router's confidence for successful runs.
            # Previously a failed child result (confidence 0.0) was overwritten
            # with the router's confidence, masking the failure from callers
            # and from get_performance_stats consumers inspecting confidence.
            if result.error is None:
                result.confidence = confidence
            result.reasoning = f"Parent LLM: {reasoning}. Child model: {result.reasoning}"

        else:
            # No child registered for this type (e.g. MULTIMODAL routes here;
            # use process_multimodal_request for those).
            result = TaskResult(
                task_type=task_type,
                input_data=user_request,
                output=None,
                processing_time=0.0,
                confidence=0.0,
                reasoning=f"Unknown task type: {task_type.value}",
                timestamp=time.time(),
                error=f"No child model available for {task_type.value}"
            )

        self.task_history.append(result)

        return result

    async def process_multimodal_request(self, image_path: str, text_prompt: str) -> Dict[str, TaskResult]:
        """
        Process a complex multimodal request requiring multiple models.

        Runs captioning on *image_path* and text-to-image on *text_prompt*
        sequentially; returns a dict keyed "caption" / "generated_image".
        """
        logger.info(f"Processing multimodal request: image={image_path}, text={text_prompt}")

        results = {}

        if TaskType.CAPTION in self.child_models:
            caption_result = await self.child_models[TaskType.CAPTION].process(image_path)
            results["caption"] = caption_result

        if TaskType.TEXT2IMG in self.child_models:
            text2img_result = await self.child_models[TaskType.TEXT2IMG].process(text_prompt)
            results["generated_image"] = text2img_result

        for result in results.values():
            self.task_history.append(result)

        return results

    def get_task_history(self) -> List[TaskResult]:
        """Get the task execution history (the live internal list, not a copy)."""
        return self.task_history

    def get_performance_stats(self) -> Dict[str, Any]:
        """Get performance statistics; empty dict when no tasks have run yet."""
        if not self.task_history:
            return {}

        total_tasks = len(self.task_history)
        successful_tasks = len([t for t in self.task_history if t.error is None])
        avg_processing_time = sum(t.processing_time for t in self.task_history) / total_tasks

        task_type_counts = {}
        for task in self.task_history:
            task_type = task.task_type.value
            task_type_counts[task_type] = task_type_counts.get(task_type, 0) + 1

        return {
            "total_tasks": total_tasks,
            "successful_tasks": successful_tasks,
            "success_rate": successful_tasks / total_tasks,
            "average_processing_time": avg_processing_time,
            "task_type_distribution": task_type_counts
        }
|
|
| |
async def demo_advanced_orchestrator():
    """Demo the advanced orchestrator by exercising parent-LLM routing on samples."""
    print("🚀 Advanced Multi-Model Orchestrator Demo")
    print("=" * 50)

    # Building the orchestrator loads the parent routing LLM.
    orchestrator = AdvancedMultiModelOrchestrator()

    # One representative request per expected routing outcome.
    sample_requests = (
        "Summarize this text about artificial intelligence",
        "Describe this image of a sunset",
        "Generate an image of a peaceful forest",
        "What is machine learning?",
        "Create a picture of a futuristic city",
    )

    print("\n📝 Testing Parent LLM Routing:")
    for prompt in sample_requests:
        print(f"\nRequest: {prompt}")

        decision = orchestrator.parent_router.analyze_request(prompt)
        print(f"Routing: {decision['task_type'].value}")
        print(f"Confidence: {decision['confidence']:.2f}")
        print(f"Reasoning: {decision['reasoning']}")

    print("\n" + "=" * 50)
    print("✅ Demo completed!")
|
|
# Script entry point: run the async demo under asyncio's event loop.
if __name__ == "__main__":
    asyncio.run(demo_advanced_orchestrator())