VibecoderMcSwaggins committed on
Commit
580b270
·
unverified ·
2 Parent(s): 949847c b9e710f

Merge pull request #137 from The-Obstacle-Is-The-Way/refactor/spec-21-middleware-architecture

Browse files
docs/specs/SPEC-21-MIDDLEWARE-ARCHITECTURE.md CHANGED
@@ -1,6 +1,6 @@
1
  # SPEC-21: Middleware Architecture Refactor
2
 
3
- **Status:** READY FOR IMPLEMENTATION
4
  **Priority:** P2 (Architectural hygiene + fixes HuggingFace retry bug)
5
  **Effort:** 2 hours
6
  **PR Scope:** Folder rename + new middleware implementations
 
1
  # SPEC-21: Middleware Architecture Refactor
2
 
3
+ **Status:** COMPLETED
4
  **Priority:** P2 (Architectural hygiene + fixes HuggingFace retry bug)
5
  **Effort:** 2 hours
6
  **PR Scope:** Folder rename + new middleware implementations
src/clients/huggingface.py CHANGED
@@ -27,6 +27,7 @@ from agent_framework._types import FunctionCallContent, FunctionResultContent
27
  from agent_framework.observability import use_observability
28
  from huggingface_hub import InferenceClient
29
 
 
30
  from src.utils.config import settings
31
 
32
  logger = structlog.get_logger()
@@ -51,7 +52,13 @@ class HuggingFaceChatClient(BaseChatClient): # type: ignore[misc]
51
  api_key: HF_TOKEN (optional, defaults to env var).
52
  **kwargs: Additional arguments passed to BaseChatClient.
53
  """
54
- super().__init__(**kwargs)
 
 
 
 
 
 
55
  # FIX: Use 7B model to stay on HuggingFace native infrastructure (avoid Novita 500s)
56
  self.model_id = model_id or settings.huggingface_model or "Qwen/Qwen2.5-7B-Instruct"
57
  self.api_key = api_key or settings.hf_token
 
27
  from agent_framework.observability import use_observability
28
  from huggingface_hub import InferenceClient
29
 
30
+ from src.middleware import RetryMiddleware, TokenTrackingMiddleware
31
  from src.utils.config import settings
32
 
33
  logger = structlog.get_logger()
 
52
  api_key: HF_TOKEN (optional, defaults to env var).
53
  **kwargs: Additional arguments passed to BaseChatClient.
54
  """
55
+ # Create middleware instances for retry and token tracking
56
+ middleware = [
57
+ RetryMiddleware(max_attempts=3, min_wait=1.0, max_wait=10.0),
58
+ TokenTrackingMiddleware(),
59
+ ]
60
+
61
+ super().__init__(middleware=middleware, **kwargs) # type: ignore[arg-type]
62
  # FIX: Use 7B model to stay on HuggingFace native infrastructure (avoid Novita 500s)
63
  self.model_id = model_id or settings.huggingface_model or "Qwen/Qwen2.5-7B-Instruct"
64
  self.api_key = api_key or settings.hf_token
src/middleware/__init__.py CHANGED
@@ -1 +1,10 @@
1
- """Middleware components for orchestration."""
 
 
 
 
 
 
 
 
 
 
1
+ """Microsoft Agent Framework middleware implementations.
2
+
3
+ These are interceptor-pattern middleware that wrap chat client calls.
4
+ They are NOT workflows - see src/workflows/ for orchestration patterns.
5
+ """
6
+
7
+ from src.middleware.retry import RetryMiddleware
8
+ from src.middleware.token_tracking import TokenTrackingMiddleware
9
+
10
+ __all__ = ["RetryMiddleware", "TokenTrackingMiddleware"]
src/middleware/retry.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Retry middleware for chat clients with exponential backoff."""
2
+
3
+ import asyncio
4
+ import random
5
+ from collections.abc import Awaitable, Callable
6
+
7
+ import structlog
8
+ from agent_framework._middleware import ChatContext, ChatMiddleware
9
+
10
+ logger = structlog.get_logger()
11
+
12
+
13
class RetryMiddleware(ChatMiddleware):
    """Retries failed chat requests with exponential backoff.

    This middleware intercepts chat client calls and retries on transient
    errors (rate limits, timeouts, server errors).

    Attributes:
        max_attempts: Maximum number of attempts (default: 3).
        min_wait: Minimum wait between retries in seconds (default: 1.0).
        max_wait: Maximum wait between retries in seconds (default: 10.0).
        retryable_status_codes: HTTP status codes to retry
            (default: 429, 500, 502, 503, 504).
    """

    def __init__(
        self,
        max_attempts: int = 3,
        min_wait: float = 1.0,
        max_wait: float = 10.0,
        retryable_status_codes: tuple[int, ...] = (429, 500, 502, 503, 504),
    ) -> None:
        # FIX: guard against a silent no-op. With max_attempts < 1 the
        # retry loop in process() never runs, so next() is never called
        # and no error is ever raised — the request silently vanishes.
        if max_attempts < 1:
            raise ValueError("max_attempts must be >= 1")
        self.max_attempts = max_attempts
        self.min_wait = min_wait
        self.max_wait = max_wait
        self.retryable_status_codes = retryable_status_codes

    def _is_retryable(self, error: Exception) -> bool:
        """Return True if *error* looks transient and is worth retrying."""
        # httpx-style HTTP errors carry a response with a status code;
        # decide purely from the status code when one is present.
        response = getattr(error, "response", None)
        if response is not None and hasattr(response, "status_code"):
            return response.status_code in self.retryable_status_codes

        # Fall back to a name-based heuristic for timeout/connection errors.
        error_name = type(error).__name__.lower()
        return "timeout" in error_name or "connection" in error_name

    def _calculate_wait(self, attempt: int) -> float:
        """Calculate wait time with exponential backoff and jitter."""
        wait = min(self.min_wait * (2**attempt), self.max_wait)
        # Add jitter (±25%) to avoid thundering-herd retries.
        jitter = wait * 0.25 * (2 * random.random() - 1)
        return float(max(self.min_wait, wait + jitter))

    async def process(
        self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
    ) -> None:
        """Process the chat request with retry logic.

        Raises:
            Exception: Immediately for non-retryable errors, or the last
                error encountered once all retry attempts are exhausted.
        """
        last_error: Exception | None = None

        for attempt in range(self.max_attempts):
            try:
                await next(context)
                return  # Success - exit retry loop

            except Exception as e:
                last_error = e

                if not self._is_retryable(e):
                    logger.warning(
                        "Non-retryable error",
                        error=str(e),
                        error_type=type(e).__name__,
                    )
                    raise  # Don't retry non-retryable errors

                # Sleep only between attempts, not after the final one.
                if attempt < self.max_attempts - 1:
                    wait_time = self._calculate_wait(attempt)
                    logger.info(
                        "Retrying after error",
                        attempt=attempt + 1,
                        max_attempts=self.max_attempts,
                        wait_seconds=wait_time,
                        error=str(e),
                    )
                    await asyncio.sleep(wait_time)

        # All retries exhausted
        logger.error(
            "All retry attempts failed",
            max_attempts=self.max_attempts,
            last_error=str(last_error),
        )
        if last_error is not None:
            raise last_error
src/middleware/token_tracking.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Token tracking middleware for monitoring API usage."""
2
+
3
+ from collections.abc import Awaitable, Callable
4
+
5
+ import structlog
6
+ from agent_framework._middleware import ChatContext, ChatMiddleware
7
+
8
+ logger = structlog.get_logger()
9
+
10
+
11
class TokenTrackingMiddleware(ChatMiddleware):
    """Tracks token usage across chat requests.

    This middleware logs token usage after each chat completion
    and maintains running totals for the session.

    Usage metrics are logged via structlog for observability.
    """

    def __init__(self) -> None:
        # Cumulative counters for the lifetime of this middleware instance.
        self.total_input_tokens = 0
        self.total_output_tokens = 0
        self.request_count = 0

    @staticmethod
    def _extract_counts(usage: object) -> tuple[int, int]:
        """Read (input_tokens, output_tokens) from a usage payload.

        Supports dict-like payloads and attribute-based objects (Pydantic
        models, etc.), with OpenAI-style fallbacks (prompt_tokens /
        completion_tokens).

        FIX: the trailing ``or 0`` guards against keys/attributes that are
        present but set to None — previously ``usage.get("input_tokens", 0)
        or usage.get("prompt_tokens", 0)`` could return None (both keys
        present with None values), crashing the ``+=`` accumulation.
        """
        if hasattr(usage, "get"):
            # Dict-like access
            input_tokens = (
                usage.get("input_tokens") or usage.get("prompt_tokens") or 0
            )
            output_tokens = (
                usage.get("output_tokens") or usage.get("completion_tokens") or 0
            )
        else:
            # Object attribute access (Pydantic models, etc.)
            input_tokens = (
                getattr(usage, "input_tokens", 0)
                or getattr(usage, "prompt_tokens", 0)
                or 0
            )
            output_tokens = (
                getattr(usage, "output_tokens", 0)
                or getattr(usage, "completion_tokens", 0)
                or 0
            )
        return input_tokens, output_tokens

    async def process(
        self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
    ) -> None:
        """Process request and track token usage."""
        await next(context)

        # Extract usage from response if available
        if context.result is None:
            return

        usage = None

        # Try to get usage from response
        if hasattr(context.result, "usage"):
            usage = context.result.usage
        elif hasattr(context.result, "messages") and context.result.messages:
            # Check first message for usage metadata
            msg = context.result.messages[0]
            if hasattr(msg, "metadata") and msg.metadata:
                usage = msg.metadata.get("usage")

        if usage:
            input_tokens, output_tokens = self._extract_counts(usage)

            self.total_input_tokens += input_tokens
            self.total_output_tokens += output_tokens
            self.request_count += 1

            logger.info(
                "Token usage",
                request_input=input_tokens,
                request_output=output_tokens,
                total_input=self.total_input_tokens,
                total_output=self.total_output_tokens,
                total_requests=self.request_count,
            )

    def get_stats(self) -> dict[str, int]:
        """Get cumulative token usage statistics.

        Returns:
            Dictionary with total_input, total_output, and request_count.
        """
        return {
            "total_input": self.total_input_tokens,
            "total_output": self.total_output_tokens,
            "request_count": self.request_count,
        }
src/orchestrators/hierarchical.py CHANGED
@@ -19,11 +19,11 @@ import structlog
19
  from src.agents.judge_agent_llm import LLMSubIterationJudge
20
  from src.agents.magentic_agents import create_search_agent
21
  from src.config.domain import ResearchDomain
22
- from src.middleware.sub_iteration import SubIterationMiddleware, SubIterationTeam
23
  from src.orchestrators.base import OrchestratorProtocol
24
  from src.state import init_magentic_state
25
  from src.utils.models import AgentEvent, OrchestratorConfig
26
  from src.utils.service_loader import get_embedding_service_if_available
 
27
 
28
  logger = structlog.get_logger()
29
 
 
19
  from src.agents.judge_agent_llm import LLMSubIterationJudge
20
  from src.agents.magentic_agents import create_search_agent
21
  from src.config.domain import ResearchDomain
 
22
  from src.orchestrators.base import OrchestratorProtocol
23
  from src.state import init_magentic_state
24
  from src.utils.models import AgentEvent, OrchestratorConfig
25
  from src.utils.service_loader import get_embedding_service_if_available
26
+ from src.workflows.sub_iteration import SubIterationMiddleware, SubIterationTeam
27
 
28
  logger = structlog.get_logger()
29
 
src/{middleware → workflows}/.gitkeep RENAMED
File without changes
src/workflows/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Workflow components for orchestration.
2
+
3
+ These are workflow patterns (e.g., team→judge loops), NOT interceptor middleware.
4
+ For interceptor middleware, see src/middleware/.
5
+ """
6
+
7
+ from src.workflows.sub_iteration import (
8
+ SubIterationJudge,
9
+ SubIterationMiddleware,
10
+ SubIterationTeam,
11
+ )
12
+
13
+ __all__ = ["SubIterationJudge", "SubIterationMiddleware", "SubIterationTeam"]
src/{middleware → workflows}/sub_iteration.py RENAMED
File without changes
tests/unit/middleware/__init__.py ADDED
File without changes
tests/unit/middleware/test_retry.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unittest.mock import AsyncMock, MagicMock
2
+
3
+ import pytest
4
+
5
+ from src.middleware.retry import RetryMiddleware
6
+
7
+ pytestmark = pytest.mark.unit
8
+
9
+
10
@pytest.mark.asyncio
async def test_retry_middleware_succeeds_first_try():
    """A successful call passes straight through with a single attempt."""
    mw = RetryMiddleware(max_attempts=3)
    ctx = MagicMock()
    downstream = AsyncMock()

    await mw.process(ctx, downstream)

    # Exactly one invocation, with the original context.
    downstream.assert_called_once_with(ctx)
20
+
21
+
22
@pytest.mark.asyncio
async def test_retry_middleware_retries_on_429():
    """A 429 rate-limit error is retried until the call succeeds."""
    mw = RetryMiddleware(max_attempts=3, min_wait=0.01)
    ctx = MagicMock()

    # Record each invocation; fail the first two with a 429, then succeed.
    attempts = []

    async def flaky_next(c):
        attempts.append(c)
        if len(attempts) < 3:
            error = Exception("Rate limited")
            error.response = MagicMock(status_code=429)
            raise error

    await mw.process(ctx, flaky_next)
    assert len(attempts) == 3
41
+
42
+
43
@pytest.mark.asyncio
async def test_retry_middleware_raises_after_max_attempts():
    """Once attempts are exhausted, the last error propagates to the caller."""
    mw = RetryMiddleware(max_attempts=2, min_wait=0.01)
    ctx = MagicMock()

    async def failing_next(c):
        # Always raise a retryable 500 so every attempt fails.
        error = Exception("Always fails")
        error.response = MagicMock(status_code=500)
        raise error

    with pytest.raises(Exception, match="Always fails"):
        await mw.process(ctx, failing_next)
tests/unit/middleware/test_token_tracking.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unittest.mock import AsyncMock, MagicMock
2
+
3
+ import pytest
4
+
5
+ from src.middleware.token_tracking import TokenTrackingMiddleware
6
+
7
+ pytestmark = pytest.mark.unit
8
+
9
+
10
+ @pytest.mark.asyncio
11
+ async def test_token_tracking_middleware_counts_tokens():
12
+ """TokenTrackingMiddleware should count tokens from response."""
13
+ middleware = TokenTrackingMiddleware()
14
+ context = MagicMock()
15
+
16
+ # Mock response with usage
17
+ context.result.usage = {"input_tokens": 10, "output_tokens": 20}
18
+
19
+ next_fn = AsyncMock()
20
+
21
+ await middleware.process(context, next_fn)
22
+
23
+ assert middleware.total_input_tokens == 10
24
+ assert middleware.total_output_tokens == 20
25
+ assert middleware.request_count == 1
26
+
27
+
28
+ @pytest.mark.asyncio
29
+ async def test_token_tracking_middleware_handles_no_usage():
30
+ """TokenTrackingMiddleware should handle response without usage gracefully."""
31
+ middleware = TokenTrackingMiddleware()
32
+ context = MagicMock()
33
+ context.result = MagicMock()
34
+ del context.result.usage # Ensure usage attr doesn't exist
35
+ context.result.messages = [] # Ensure no messages
36
+
37
+ next_fn = AsyncMock()
38
+
39
+ await middleware.process(context, next_fn)
40
+
41
+ assert middleware.total_input_tokens == 0
42
+ assert middleware.total_output_tokens == 0
43
+ assert middleware.request_count == 0
tests/unit/test_hierarchical.py CHANGED
@@ -4,8 +4,8 @@ from unittest.mock import AsyncMock
4
 
5
  import pytest
6
 
7
- from src.middleware.sub_iteration import SubIterationMiddleware
8
  from src.utils.models import AssessmentDetails, JudgeAssessment
 
9
 
10
  pytestmark = pytest.mark.unit
11
 
 
4
 
5
  import pytest
6
 
 
7
  from src.utils.models import AssessmentDetails, JudgeAssessment
8
+ from src.workflows.sub_iteration import SubIterationMiddleware
9
 
10
  pytestmark = pytest.mark.unit
11