# advanced-multi-model-orchestrator-v2 / advanced_orchestrator.py
# Uploaded by kunaliitkgp09 with huggingface_hub (commit 0152b60, verified)
#!/usr/bin/env python3
"""
Advanced Multi-Model Orchestrator with Parent LLM Reasoning
This version uses a parent LLM to intelligently analyze user requests and route them
to the most appropriate child model based on reasoning rather than simple heuristics.
"""
import asyncio
import json
import logging
import re
import time
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, Optional, Union, Any

import torch
from PIL import Image
from diffusers import StableDiffusionPipeline
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoProcessor,
    pipeline
)
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TaskType(Enum):
    """Task categories the parent LLM can route a request to."""

    # Plain text work: summarization, Q&A, generic text processing.
    TEXT = "TEXT"
    # Describe an image with a natural-language caption.
    CAPTION = "CAPTION"
    # Generate an image from a text prompt.
    TEXT2IMG = "TEXT2IMG"
    # Complex tasks that need several models in concert.
    MULTIMODAL = "MULTIMODAL"
@dataclass
class ModelConfig:
    """Configuration for child models (consumed by AdvancedChildModel)."""
    # HF Hub id or local path of the model to load.
    name: str
    # Which task this model serves; selects the load/process branch taken.
    model_type: TaskType
    # NOTE(review): this default is evaluated once at class-definition time,
    # so every instance created without an explicit device shares the value
    # detected at import — intentional here, but worth knowing.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Generation cap passed as max_length to model.generate().
    max_length: int = 512
    # Sampling temperature forwarded to model.generate().
    temperature: float = 0.7
@dataclass
class TaskResult:
    """Result from a task execution."""
    # Which child model produced this result.
    task_type: TaskType
    # The raw user request (or image path) that was processed.
    input_data: str
    # Model output: caption/text string, generated-image path, or None on error.
    output: Any
    # Wall-clock seconds spent in processing (including lazy model load).
    processing_time: float
    # 0.0 on failure; otherwise the router/child confidence in [0, 1].
    confidence: float
    # Human-readable explanation of how this result was produced.
    reasoning: str
    # time.time() when the result was created.
    timestamp: float
    # Exception text when processing failed; None on success.
    error: Optional[str] = None
class ParentLLMRouter:
    """
    Parent LLM that uses reasoning to route tasks to appropriate child models.

    The LLM's free-text answer is scored against per-task keyword lists to pick
    a TaskType; if generation fails for any reason, a keyword heuristic over
    the raw user request is used as a fallback (never raises).
    """

    def __init__(self, model_name: str = "distilgpt2", device: str = None):
        """Load the routing LLM onto *device* (auto-selects CUDA when None)."""
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        # Load the parent LLM for routing decisions
        logger.info(f"Loading parent LLM: {model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
        # GPT-2-family tokenizers ship without a pad token; reuse EOS so
        # padded generation works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Routing prompt template
        self.routing_prompt_template = """You are a router. Analyze this user request and choose the best model:
- "TEXT" for text summarization, Q&A, or text processing
- "CAPTION" for describing images
- "TEXT2IMG" for generating images from text
- "MULTIMODAL" for complex tasks requiring multiple models
Respond only with one keyword: TEXT, CAPTION, TEXT2IMG, or MULTIMODAL.
User request: {user_request}
Response:"""

    def analyze_request(self, user_request: str) -> Dict[str, Any]:
        """
        Use the parent LLM to analyze the request and determine the best routing.

        Returns a dict with keys "task_type" (TaskType), "confidence" (float),
        "reasoning" (str) and "raw_response" (str). Never raises: any failure
        falls back to heuristic routing.
        """
        try:
            # Create the routing prompt
            prompt = self.routing_prompt_template.format(user_request=user_request)
            # Tokenize the prompt
            inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            # Generate response
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=10,
                    temperature=0.1,  # low temperature for consistent routing
                    do_sample=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )
            # FIX: decode only the newly generated tokens. Decoding outputs[0]
            # in full echoes the prompt, and the prompt itself contains every
            # routing keyword ("TEXT", "CAPTION", "TEXT2IMG", ...), so the
            # keyword scorer was matching the prompt instead of the answer.
            prompt_length = inputs["input_ids"].shape[1]
            response = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            )
            # Extract the routing decision
            routing_decision = self._extract_routing_decision(response, user_request)
            return {
                "task_type": routing_decision["task_type"],
                "confidence": routing_decision["confidence"],
                "reasoning": routing_decision["reasoning"],
                "raw_response": response
            }
        except Exception as e:
            logger.error(f"Error in parent LLM routing: {e}")
            # Fallback to heuristic routing
            return self._fallback_routing(user_request)

    def _extract_routing_decision(self, response: str, user_request: str) -> Dict[str, Any]:
        """
        Score the LLM *response* against per-task keyword lists and return the
        task type with the most whole-word matches, plus a confidence and a
        human-readable reasoning string.
        """
        response_upper = response.upper()
        # Keywords that indicate each task type.
        keyword_sets = {
            TaskType.TEXT: ["TEXT", "SUMMARIZE", "QUESTION", "ANSWER", "PROCESS"],
            TaskType.CAPTION: ["CAPTION", "DESCRIBE", "IMAGE", "PICTURE", "PHOTO"],
            TaskType.TEXT2IMG: ["TEXT2IMG", "GENERATE", "CREATE", "DRAW", "PAINT"],
            TaskType.MULTIMODAL: ["MULTIMODAL", "BOTH", "COMBINE", "TOGETHER"]
        }
        # FIX: match on word boundaries so e.g. "TEXT" is not counted inside
        # "TEXT2IMG" (plain substring search double-counted overlapping
        # keywords and skewed the scores toward TEXT).
        scores = {
            task: sum(
                1 for keyword in keywords
                if re.search(rf"\b{re.escape(keyword)}\b", response_upper)
            )
            for task, keywords in keyword_sets.items()
        }
        # Highest-scoring task wins; dict insertion order breaks ties (TEXT first),
        # matching the original score-comparison order.
        task_type = max(scores, key=scores.get)
        max_score = scores[task_type]
        # Confidence is the winner's share of all matches; 0.25 = uniform prior
        # over the four task types when nothing matched.
        total_score = sum(scores.values())
        confidence = max_score / total_score if total_score > 0 else 0.25
        # Generate reasoning
        reasoning = f"Parent LLM analyzed request and determined {task_type.value} task with {confidence:.2f} confidence"
        return {
            "task_type": task_type,
            "confidence": confidence,
            "reasoning": reasoning
        }

    def _fallback_routing(self, user_request: str) -> Dict[str, Any]:
        """
        Fallback routing using simple keyword heuristics when the LLM fails.

        Returns the same dict shape as analyze_request, with a fixed 0.5
        confidence to signal the weaker decision path.
        """
        user_request_lower = user_request.lower()
        # Simple keyword-based routing; first matching branch wins.
        if any(word in user_request_lower for word in ["image", "picture", "photo", "describe", "caption"]):
            task_type = TaskType.CAPTION
            reasoning = "Fallback: Detected image-related keywords"
        elif any(word in user_request_lower for word in ["generate", "create", "draw", "paint", "image from"]):
            task_type = TaskType.TEXT2IMG
            reasoning = "Fallback: Detected image generation keywords"
        elif any(word in user_request_lower for word in ["summarize", "question", "answer", "text"]):
            task_type = TaskType.TEXT
            reasoning = "Fallback: Detected text processing keywords"
        else:
            task_type = TaskType.TEXT
            reasoning = "Fallback: Default to text processing"
        return {
            "task_type": task_type,
            "confidence": 0.5,  # Lower confidence for fallback
            "reasoning": reasoning,
            "raw_response": "Fallback routing used"
        }
class AdvancedChildModel:
    """Base class for child models with lazy, asynchronous loading.

    One instance wraps exactly one model, selected by config.model_type:
    CAPTION (AutoProcessor + causal LM), TEXT2IMG (StableDiffusionPipeline)
    or TEXT (AutoTokenizer + causal LM). Loading is deferred until the first
    process() call.
    """

    def __init__(self, config: ModelConfig):
        self.config = config
        self.model = None
        self.processor = None
        # FIX: declare the tokenizer up front. Previously it was created only
        # inside the TEXT branch of load_model, so any code path touching
        # self.tokenizer on a CAPTION/TEXT2IMG instance raised AttributeError
        # instead of seeing None like the other model attributes.
        self.tokenizer = None
        self.is_loaded = False

    async def load_model(self):
        """Load the underlying model once; subsequent calls are no-ops."""
        if self.is_loaded:
            return
        try:
            logger.info(f"Loading {self.config.model_type.value} model: {self.config.name}")
            if self.config.model_type == TaskType.CAPTION:
                self.processor = AutoProcessor.from_pretrained(self.config.name)
                self.model = AutoModelForCausalLM.from_pretrained(self.config.name).to(self.config.device)
            elif self.config.model_type == TaskType.TEXT2IMG:
                self.model = StableDiffusionPipeline.from_pretrained(self.config.name).to(self.config.device)
            elif self.config.model_type == TaskType.TEXT:
                self.tokenizer = AutoTokenizer.from_pretrained(self.config.name)
                self.model = AutoModelForCausalLM.from_pretrained(self.config.name).to(self.config.device)
            self.is_loaded = True
            logger.info(f"Successfully loaded {self.config.model_type.value} model")
        except Exception as e:
            logger.error(f"Error loading {self.config.model_type.value} model: {e}")
            raise

    async def process(self, input_data: str, **kwargs) -> TaskResult:
        """
        Process *input_data* with this child model and return a TaskResult.

        Never raises: failures are captured into a TaskResult with error set,
        output None and confidence 0.0.
        """
        start_time = time.time()
        try:
            if not self.is_loaded:
                await self.load_model()
            # Dispatch on the configured model type.
            if self.config.model_type == TaskType.CAPTION:
                output = await self._process_caption(input_data)
            elif self.config.model_type == TaskType.TEXT2IMG:
                output = await self._process_text2img(input_data)
            elif self.config.model_type == TaskType.TEXT:
                output = await self._process_text(input_data)
            else:
                raise ValueError(f"Unknown model type: {self.config.model_type}")
            processing_time = time.time() - start_time
            return TaskResult(
                task_type=self.config.model_type,
                input_data=input_data,
                output=output,
                processing_time=processing_time,
                confidence=0.9,  # High confidence for successful processing
                reasoning=f"Successfully processed with {self.config.model_type.value} model",
                timestamp=time.time()
            )
        except Exception as e:
            processing_time = time.time() - start_time
            logger.error(f"Error processing with {self.config.model_type.value} model: {e}")
            return TaskResult(
                task_type=self.config.model_type,
                input_data=input_data,
                output=None,
                processing_time=processing_time,
                confidence=0.0,
                reasoning=f"Error: {str(e)}",
                timestamp=time.time(),
                error=str(e)
            )

    async def _process_caption(self, image_path: str) -> str:
        """Caption the image at *image_path* and return the caption text."""
        image = Image.open(image_path).convert("RGB")
        inputs = self.processor(images=image, return_tensors="pt").to(self.config.device)
        with torch.no_grad():
            # NOTE(review): temperature without do_sample=True is ignored by
            # generate (greedy decoding) — left as-is to preserve behavior.
            outputs = self.model.generate(
                **inputs,
                max_length=self.config.max_length,
                temperature=self.config.temperature
            )
        caption = self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
        return caption

    async def _process_text2img(self, text_prompt: str) -> str:
        """Generate an image from *text_prompt*; returns the saved PNG path."""
        image = self.model(text_prompt).images[0]
        # Save the generated image with a unix-timestamp filename.
        output_path = f"generated_image_{int(time.time())}.png"
        image.save(output_path)
        return output_path

    async def _process_text(self, text_input: str) -> str:
        """Run the text LM on *text_input* (summarization, Q&A, etc.)."""
        inputs = self.tokenizer(text_input, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(self.config.device) for k, v in inputs.items()}
        with torch.no_grad():
            # NOTE(review): temperature without do_sample=True is ignored by
            # generate (greedy decoding) — left as-is to preserve behavior.
            outputs = self.model.generate(
                **inputs,
                max_length=self.config.max_length,
                temperature=self.config.temperature
            )
        result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return result
class AdvancedMultiModelOrchestrator:
    """
    Advanced multi-model orchestrator with parent LLM reasoning.

    A ParentLLMRouter picks a TaskType for each request, which is dispatched
    to the matching lazily-loaded AdvancedChildModel. Every TaskResult is
    appended to task_history for later statistics.
    """

    def __init__(self, parent_model_name: str = "distilgpt2"):
        self.parent_router = ParentLLMRouter(parent_model_name)
        self.child_models: Dict[TaskType, AdvancedChildModel] = {}
        self.task_history: List[TaskResult] = []
        # Initialize child models
        self._initialize_child_models()

    def _initialize_child_models(self):
        """Create (unloaded) child model wrappers for each supported task."""
        model_configs = {
            TaskType.CAPTION: ModelConfig(
                name="kunaliitkgp09/clip-gpt2-image-captioner",
                model_type=TaskType.CAPTION
            ),
            TaskType.TEXT2IMG: ModelConfig(
                name="kunaliitkgp09/flickr30k-text-to-image",
                model_type=TaskType.TEXT2IMG
            ),
            TaskType.TEXT: ModelConfig(
                name="distilgpt2",  # Using distilgpt2 for text processing
                model_type=TaskType.TEXT
            )
        }
        for task_type, config in model_configs.items():
            self.child_models[task_type] = AdvancedChildModel(config)

    async def process_request(self, user_request: str) -> TaskResult:
        """
        Process *user_request*: route with the parent LLM, run the chosen
        child model, log the result to task_history and return it.
        """
        logger.info(f"Processing request: {user_request}")
        # Step 1: Parent LLM analyzes the request
        routing_decision = self.parent_router.analyze_request(user_request)
        task_type = routing_decision["task_type"]
        confidence = routing_decision["confidence"]
        reasoning = routing_decision["reasoning"]
        logger.info(f"Parent LLM routing decision: {task_type.value} (confidence: {confidence:.2f})")
        logger.info(f"Reasoning: {reasoning}")
        # Step 2: Route to appropriate child model
        if task_type in self.child_models:
            child_model = self.child_models[task_type]
            result = await child_model.process(user_request)
            # FIX: only adopt the router's confidence on success. Previously a
            # failed child result (confidence 0.0, error set) was overwritten
            # with the router's higher confidence, making failures look
            # trustworthy to downstream consumers.
            if result.error is None:
                result.confidence = confidence
            result.reasoning = f"Parent LLM: {reasoning}. Child model: {result.reasoning}"
        else:
            # Handle unknown task type
            result = TaskResult(
                task_type=task_type,
                input_data=user_request,
                output=None,
                processing_time=0.0,
                confidence=0.0,
                reasoning=f"Unknown task type: {task_type.value}",
                timestamp=time.time(),
                error=f"No child model available for {task_type.value}"
            )
        # Step 3: Log the task
        self.task_history.append(result)
        return result

    async def process_multimodal_request(self, image_path: str, text_prompt: str) -> Dict[str, TaskResult]:
        """
        Process a complex multimodal request requiring multiple models:
        caption *image_path* and generate an image from *text_prompt*.
        Returns a dict with keys "caption" and/or "generated_image".
        """
        logger.info(f"Processing multimodal request: image={image_path}, text={text_prompt}")
        results = {}
        # Process image captioning
        if TaskType.CAPTION in self.child_models:
            caption_result = await self.child_models[TaskType.CAPTION].process(image_path)
            results["caption"] = caption_result
        # Process text-to-image generation
        if TaskType.TEXT2IMG in self.child_models:
            text2img_result = await self.child_models[TaskType.TEXT2IMG].process(text_prompt)
            results["generated_image"] = text2img_result
        # Log all results
        for result in results.values():
            self.task_history.append(result)
        return results

    def get_task_history(self) -> List[TaskResult]:
        """Get the task execution history (live list, not a copy)."""
        return self.task_history

    def get_performance_stats(self) -> Dict[str, Any]:
        """Get aggregate performance statistics; empty dict if no tasks yet."""
        if not self.task_history:
            return {}
        total_tasks = len(self.task_history)
        successful_tasks = len([t for t in self.task_history if t.error is None])
        avg_processing_time = sum(t.processing_time for t in self.task_history) / total_tasks
        # Count executed tasks per task-type name.
        task_type_counts = {}
        for task in self.task_history:
            task_type = task.task_type.value
            task_type_counts[task_type] = task_type_counts.get(task_type, 0) + 1
        return {
            "total_tasks": total_tasks,
            "successful_tasks": successful_tasks,
            "success_rate": successful_tasks / total_tasks,
            "average_processing_time": avg_processing_time,
            "task_type_distribution": task_type_counts
        }
# Demo and testing functions
async def demo_advanced_orchestrator():
    """Exercise the parent-LLM router on a handful of sample requests."""
    banner = "=" * 50
    print("🚀 Advanced Multi-Model Orchestrator Demo")
    print(banner)

    # Build the orchestrator (this loads the parent routing LLM).
    orchestrator = AdvancedMultiModelOrchestrator()

    # Sample requests covering text work, captioning and image generation.
    sample_requests = (
        "Summarize this text about artificial intelligence",
        "Describe this image of a sunset",
        "Generate an image of a peaceful forest",
        "What is machine learning?",
        "Create a picture of a futuristic city",
    )

    print("\n📝 Testing Parent LLM Routing:")
    for sample in sample_requests:
        print(f"\nRequest: {sample}")
        # Ask the router for its decision only — no child model is executed.
        decision = orchestrator.parent_router.analyze_request(sample)
        print(f"Routing: {decision['task_type'].value}")
        print(f"Confidence: {decision['confidence']:.2f}")
        print(f"Reasoning: {decision['reasoning']}")

    print("\n" + banner)
    print("✅ Demo completed!")
if __name__ == "__main__":
    # Entry point: run the async demo on a fresh event loop.
    asyncio.run(demo_advanced_orchestrator())