File size: 4,739 Bytes
4974012
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574e4e7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
"""Model chaining engine for multi-stage AI pipelines."""

from __future__ import annotations

import asyncio
import json
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Any, Callable

from loguru import logger


@dataclass(frozen=True, slots=True)
class ChainStage:
    """A single stage in a model chain.

    Immutable descriptor pairing a concrete model with the role that
    model plays inside a multi-stage pipeline.
    """

    # Provider-prefixed model identifier; the text before the first "/"
    # selects the provider (see ChainEngine.execute_simple_chain).
    model_ref: str  # e.g., "zen/minimax-m2.5-free"
    # Machine-readable name for the stage's role in the pipeline.
    stage_name: str  # e.g., "vision_analysis", "code_generation"
    # Human-readable summary of what this stage is expected to do.
    description: str


@dataclass(frozen=True, slots=True)
class ChainResult:
    """Result from executing a chain stage."""

    # The stage that produced this result.
    stage: ChainStage
    # Collected text output of the stage (empty on failure).
    output: str
    # True when the stage completed without raising.
    success: bool
    # Error description when success is False; None otherwise.
    error: str | None = None


# Chain templates for common multi-capability tasks.
# Keys are template names returned by ChainEngine.get_chain_for_requirements;
# values are ordered stage lists executed first-to-last.
CHAIN_TEMPLATES: dict[str, list[ChainStage]] = {
    # Vision model first to describe the image, then a text model to
    # produce the final answer from that description.
    "vision_to_text": [
        ChainStage(
            model_ref="nvidia_nim/stepfun-ai/step-3.5-flash",
            stage_name="image_analysis",
            description="Analyze image content",
        ),
        ChainStage(
            model_ref="zen/minimax-m2.5-free",
            stage_name="response_generation",
            description="Generate final response",
        ),
    ],
    # Reasoning/coding model plans first, then a text model writes the
    # final output from that plan.
    "reasoning_to_generation": [
        ChainStage(
            model_ref="nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct",
            stage_name="analysis",
            description="Analyze and plan",
        ),
        ChainStage(
            model_ref="zen/minimax-m2.5-free",
            stage_name="generation",
            description="Generate output",
        ),
    ],
}


class ChainEngine:
    """Execute multi-model pipelines for complex requests."""

    def __init__(self, provider_getter: Callable[[str], Any]):
        # Resolver mapping a provider prefix (e.g. "zen") to a provider object.
        self._provider_getter = provider_getter

    async def execute_simple_chain(
        self,
        stages: list[ChainStage],
        initial_messages: list[Any],
        system_prompt: str | None = None,
    ) -> AsyncIterator[str]:
        """Run a chain of models sequentially, streaming SSE events back.

        Args:
            stages: Ordered chain stages to execute.
            initial_messages: Messages that seed the first stage.
            system_prompt: Optional system prompt forwarded to the provider.

        Yields:
            SSE events from the final model in the chain.
        """
        # Nothing to run for an empty chain.
        if not stages:
            return

        logger.info("ChainEngine: executing {} stages", len(stages))

        # Placeholder behavior: only the first stage's model is executed
        # until full multi-stage orchestration is integrated.
        lead_stage = stages[0]
        provider_key = lead_stage.model_ref.split("/")[0]
        backend = self._provider_getter(provider_key)

        logger.info(
            "ChainEngine: using model {} for chain",
            lead_stage.model_ref,
        )

        # Phase 1: delegate streaming straight to the resolved provider;
        # the chaining infrastructure is in place for later phases.
        stream = backend.stream_response(initial_messages, system_prompt, {})
        async for sse_event in stream:
            yield sse_event

    def get_chain_for_requirements(
        self,
        required_capabilities: set[str],
        available_models: list[str],
    ) -> list[ChainStage] | None:
        """Pick a chain template matching the required capabilities.

        Args:
            required_capabilities: Set of capabilities needed.
            available_models: Available model references (currently unused;
                kept for forward compatibility — TODO wire into selection).

        Returns:
            Chain stages, or None when a single model is sufficient.
        """
        caps = required_capabilities

        # A single capability (or none) never needs a chain.
        if len(caps) <= 1:
            return None

        # Vision combined with either coding or reasoning routes through
        # the vision-first template.
        if "vision" in caps and ("coding" in caps or "reasoning" in caps):
            return CHAIN_TEMPLATES.get("vision_to_text")

        # Pure reasoning + coding goes plan-then-generate.
        if {"reasoning", "coding"} <= caps:
            return CHAIN_TEMPLATES.get("reasoning_to_generation")

        # Any other combination: no chain for now.
        return None


async def execute_model_for_stage(
    provider: Any,
    messages: list[Any],
    system: str | None,
    metadata: dict[str, Any],
) -> str:
    """Execute a single model stage and return its output."""
    output_parts = []

    try:
        async for event in provider.stream_response(messages, system, metadata):
            # Parse SSE and collect text output
            if "content_block_delta" in event:
                # Extract text from delta
                pass

        return "".join(output_parts)
    except Exception as e:
        logger.error("Chain stage failed: {}", e)
        raise