# claude-code-proxy / core/chain_engine.py
# Source: Yash030, commit 574e4e7 — "Implement image support in proxy with vision-aware routing"
"""Model chaining engine for multi-stage AI pipelines."""
from __future__ import annotations

import asyncio
import json
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any, Callable

from loguru import logger
@dataclass(frozen=True, slots=True)
class ChainStage:
    """A single stage in a model chain.

    Frozen + slots: instances are immutable, hashable, and lightweight,
    so template stages can be shared safely across requests.
    """

    model_ref: str  # provider-qualified model id, e.g. "zen/minimax-m2.5-free"
    stage_name: str  # short machine-readable label, e.g. "vision_analysis", "code_generation"
    description: str  # human-readable purpose of this stage
@dataclass(frozen=True, slots=True)
class ChainResult:
    """Result from executing a chain stage.

    `error` is only expected to be set when `success` is False.
    """

    stage: ChainStage  # the stage that produced this result
    output: str  # accumulated text output of the stage
    success: bool  # True when the stage completed without raising
    error: str | None = None  # error description on failure, else None
# Chain templates for common multi-capability tasks.
# Keyed by template name; each value is an ordered list of stages executed
# first-to-last. Model refs are "provider/model" strings resolved at runtime
# by ChainEngine's provider getter.
CHAIN_TEMPLATES: dict[str, list[ChainStage]] = {
    # Vision input first, then a text model composes the final answer.
    "vision_to_text": [
        ChainStage(
            model_ref="nvidia_nim/stepfun-ai/step-3.5-flash",
            stage_name="image_analysis",
            description="Analyze image content",
        ),
        ChainStage(
            model_ref="zen/minimax-m2.5-free",
            stage_name="response_generation",
            description="Generate final response",
        ),
    ],
    # A reasoning/planning pass feeding a generation pass.
    "reasoning_to_generation": [
        ChainStage(
            model_ref="nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct",
            stage_name="analysis",
            description="Analyze and plan",
        ),
        ChainStage(
            model_ref="zen/minimax-m2.5-free",
            stage_name="generation",
            description="Generate output",
        ),
    ],
}
class ChainEngine:
    """Execute multi-model pipelines for complex requests."""

    def __init__(self, provider_getter: Callable[[str], Any]):
        # Maps a provider name (the part of a model_ref before "/") to a
        # provider object exposing `stream_response`.
        self._provider_getter = provider_getter

    async def execute_simple_chain(
        self,
        stages: list[ChainStage],
        initial_messages: list[Any],
        system_prompt: str | None = None,
    ) -> AsyncIterator[str]:
        """Execute a chain of models sequentially.

        Args:
            stages: List of chain stages to execute
            initial_messages: Initial user messages
            system_prompt: Optional system prompt

        Yields:
            SSE events from the final model in the chain
        """
        # Nothing to run: empty chain yields nothing.
        if not stages:
            return

        logger.info("ChainEngine: executing {} stages", len(stages))

        # Phase 1 placeholder: only the first stage actually runs. True
        # stage-to-stage hand-off needs deeper provider integration.
        stage = stages[0]
        provider_name = stage.model_ref.split("/")[0]
        provider = self._provider_getter(provider_name)
        logger.info(
            "ChainEngine: using model {} for chain",
            stage.model_ref,
        )

        # Delegate streaming directly to the selected provider for now.
        stream = provider.stream_response(initial_messages, system_prompt, {})
        async for sse_event in stream:
            yield sse_event

    def get_chain_for_requirements(
        self,
        required_capabilities: set[str],
        available_models: list[str],
    ) -> list[ChainStage] | None:
        """Determine the appropriate chain based on required capabilities.

        Args:
            required_capabilities: Set of capabilities needed
            available_models: Available model references (currently unused)

        Returns:
            Chain stages or None if single model is sufficient
        """
        # One (or zero) capabilities never needs a multi-model chain.
        if len(required_capabilities) <= 1:
            return None

        # Capability combinations mapped to templates, checked in priority
        # order; the first subset match wins.
        pairings = (
            ({"vision", "coding"}, "vision_to_text"),
            ({"vision", "reasoning"}, "vision_to_text"),
            ({"reasoning", "coding"}, "reasoning_to_generation"),
        )
        for combo, template_name in pairings:
            if combo <= required_capabilities:
                return CHAIN_TEMPLATES.get(template_name)

        # No matching template: fall back to single-model handling.
        return None
async def execute_model_for_stage(
    provider: Any,
    messages: list[Any],
    system: str | None,
    metadata: dict[str, Any],
) -> str:
    """Execute a single model stage and return its accumulated text output.

    Streams SSE events from the provider, extracts text deltas from
    ``content_block_delta`` events, and joins them into one string.

    Args:
        provider: Object exposing ``stream_response(messages, system, metadata)``
            as an async iterator of SSE event strings.
        messages: Messages to send to the model.
        system: Optional system prompt.
        metadata: Provider-specific request metadata, passed through as-is.

    Returns:
        Concatenated text of all deltas (empty string if none were emitted).

    Raises:
        Re-raises any exception from the provider stream after logging it.
    """
    output_parts: list[str] = []
    try:
        async for event in provider.stream_response(messages, system, metadata):
            # Only content_block_delta events carry incremental text.
            if "content_block_delta" not in event:
                continue
            for line in event.splitlines():
                if not line.startswith("data:"):
                    continue
                try:
                    payload = json.loads(line[len("data:"):].strip())
                except ValueError:
                    # Malformed/partial data line: skip rather than abort.
                    continue
                delta = payload.get("delta") or {}
                text = delta.get("text")
                if text:
                    output_parts.append(text)
        return "".join(output_parts)
    except Exception as e:
        logger.error("Chain stage failed: {}", e)
        raise