"""Model chaining engine for multi-stage AI pipelines."""
from __future__ import annotations

import asyncio
import json
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any, Callable

from loguru import logger
@dataclass(frozen=True, slots=True)
class ChainStage:
"""A single stage in a model chain."""
model_ref: str # e.g., "zen/minimax-m2.5-free"
stage_name: str # e.g., "vision_analysis", "code_generation"
description: str
@dataclass(frozen=True, slots=True)
class ChainResult:
"""Result from executing a chain stage."""
stage: ChainStage
output: str
success: bool
error: str | None = None
# Chain templates for common multi-capability tasks.
# Each template is an ordered pipeline: earlier stages analyze/plan, the
# final stage produces the user-facing response. Callers should treat these
# lists as read-only shared constants.
CHAIN_TEMPLATES: dict[str, list[ChainStage]] = {
    # Vision model first, then a text model for the final answer.
    "vision_to_text": [
        ChainStage(
            model_ref="nvidia_nim/stepfun-ai/step-3.5-flash",
            stage_name="image_analysis",
            description="Analyze image content",
        ),
        ChainStage(
            model_ref="zen/minimax-m2.5-free",
            stage_name="response_generation",
            description="Generate final response",
        ),
    ],
    # Reasoning/planning model first, then a generation model.
    "reasoning_to_generation": [
        ChainStage(
            model_ref="nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct",
            stage_name="analysis",
            description="Analyze and plan",
        ),
        ChainStage(
            model_ref="zen/minimax-m2.5-free",
            stage_name="generation",
            description="Generate output",
        ),
    ],
}
class ChainEngine:
    """Execute multi-model pipelines for complex requests.

    Stages are resolved to providers via ``provider_getter``. In Phase 1
    only the first stage's model actually streams the response; full
    stage-to-stage chaining requires deeper integration and comes later.
    """

    def __init__(self, provider_getter: Callable[[str], Any]):
        # provider_getter maps a provider key (the prefix of a model ref,
        # e.g. "zen" in "zen/minimax-m2.5-free") to a provider object
        # exposing stream_response(messages, system, metadata).
        self._provider_getter = provider_getter

    async def execute_simple_chain(
        self,
        stages: list[ChainStage],
        initial_messages: list[Any],
        system_prompt: str | None = None,
    ) -> AsyncIterator[str]:
        """Execute a chain of models sequentially.

        Args:
            stages: List of chain stages to execute.
            initial_messages: Initial user messages.
            system_prompt: Optional system prompt.

        Yields:
            SSE events from the model handling the chain.
        """
        if not stages:
            return

        logger.info("ChainEngine: executing {} stages", len(stages))

        # Phase 1 placeholder: delegate the whole request to the first
        # stage's provider instead of running every stage.
        first_stage = stages[0]
        # Provider key is everything before the first "/" in the ref.
        provider_key, _, _ = first_stage.model_ref.partition("/")
        provider = self._provider_getter(provider_key)
        logger.info(
            "ChainEngine: using model {} for chain",
            first_stage.model_ref,
        )
        async for event in provider.stream_response(
            initial_messages, system_prompt, {}
        ):
            yield event

    def get_chain_for_requirements(
        self,
        required_capabilities: set[str],
        available_models: list[str],
    ) -> list[ChainStage] | None:
        """Determine the appropriate chain based on required capabilities.

        Args:
            required_capabilities: Set of capabilities needed.
            available_models: Available model references (currently unused;
                reserved for availability-aware selection).

        Returns:
            A fresh list of chain stages, or None when a single model is
            sufficient or no template matches.
        """
        # A single capability (or none) never needs a chain.
        if len(required_capabilities) <= 1:
            return None

        caps = required_capabilities
        if "vision" in caps and ("coding" in caps or "reasoning" in caps):
            template_key = "vision_to_text"
        elif "reasoning" in caps and "coding" in caps:
            template_key = "reasoning_to_generation"
        else:
            # Default: no chain for now.
            return None

        template = CHAIN_TEMPLATES.get(template_key)
        if template is None:
            return None
        # BUGFIX: return a copy — handing out the shared template list let
        # callers mutate CHAIN_TEMPLATES for every subsequent request.
        return list(template)
async def execute_model_for_stage(
provider: Any,
messages: list[Any],
system: str | None,
metadata: dict[str, Any],
) -> str:
"""Execute a single model stage and return its output."""
output_parts = []
try:
async for event in provider.stream_response(messages, system, metadata):
# Parse SSE and collect text output
if "content_block_delta" in event:
# Extract text from delta
pass
return "".join(output_parts)
except Exception as e:
logger.error("Chain stage failed: {}", e)
raise