VibecoderMcSwaggins committed on
Commit
b9e710f
·
1 Parent(s): 3572ba0

fix: address CodeRabbit review feedback for SPEC-21

Browse files

- token_tracking.py: Handle both dict and Pydantic object access for usage
- retry.py: Add jitter (±25%) to exponential backoff to prevent thundering herd
- test_retry.py: Add pytestmark = pytest.mark.unit per coding guidelines
- test_token_tracking.py: Add pytestmark = pytest.mark.unit per coding guidelines

src/middleware/retry.py CHANGED
@@ -1,6 +1,7 @@
1
  """Retry middleware for chat clients with exponential backoff."""
2
 
3
  import asyncio
 
4
  from collections.abc import Awaitable, Callable
5
 
6
  import structlog
@@ -52,9 +53,12 @@ class RetryMiddleware(ChatMiddleware):
52
  return False
53
 
54
  def _calculate_wait(self, attempt: int) -> float:
55
- """Calculate wait time with exponential backoff."""
56
  wait = self.min_wait * (2**attempt)
57
- return float(min(wait, self.max_wait))
 
 
 
58
 
59
  async def process(
60
  self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
 
1
  """Retry middleware for chat clients with exponential backoff."""
2
 
3
  import asyncio
4
+ import random
5
  from collections.abc import Awaitable, Callable
6
 
7
  import structlog
 
53
  return False
54
 
55
  def _calculate_wait(self, attempt: int) -> float:
56
+ """Calculate wait time with exponential backoff and jitter."""
57
  wait = self.min_wait * (2**attempt)
58
+ wait = min(wait, self.max_wait)
59
+ # Add jitter (±25%) to avoid thundering herd
60
+ jitter = wait * 0.25 * (2 * random.random() - 1)
61
+ return float(max(self.min_wait, wait + jitter))
62
 
63
  async def process(
64
  self, context: ChatContext, next: Callable[[ChatContext], Awaitable[None]]
src/middleware/token_tracking.py CHANGED
@@ -44,8 +44,19 @@ class TokenTrackingMiddleware(ChatMiddleware):
44
  usage = msg.metadata.get("usage")
45
 
46
  if usage:
47
- input_tokens = usage.get("input_tokens", 0) or usage.get("prompt_tokens", 0)
48
- output_tokens = usage.get("output_tokens", 0) or usage.get("completion_tokens", 0)
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  self.total_input_tokens += input_tokens
51
  self.total_output_tokens += output_tokens
 
44
  usage = msg.metadata.get("usage")
45
 
46
  if usage:
47
+ # Handle both dict-like and object attribute access
48
+ if hasattr(usage, "get"):
49
+ # Dict-like access
50
+ input_tokens = usage.get("input_tokens", 0) or usage.get("prompt_tokens", 0)
51
+ output_tokens = usage.get("output_tokens", 0) or usage.get("completion_tokens", 0)
52
+ else:
53
+ # Object attribute access (Pydantic models, etc.)
54
+ input_tokens = getattr(usage, "input_tokens", 0) or getattr(
55
+ usage, "prompt_tokens", 0
56
+ )
57
+ output_tokens = getattr(usage, "output_tokens", 0) or getattr(
58
+ usage, "completion_tokens", 0
59
+ )
60
 
61
  self.total_input_tokens += input_tokens
62
  self.total_output_tokens += output_tokens
tests/unit/middleware/test_retry.py CHANGED
@@ -4,6 +4,8 @@ import pytest
4
 
5
  from src.middleware.retry import RetryMiddleware
6
 
 
 
7
 
8
  @pytest.mark.asyncio
9
  async def test_retry_middleware_succeeds_first_try():
 
4
 
5
  from src.middleware.retry import RetryMiddleware
6
 
7
+ pytestmark = pytest.mark.unit
8
+
9
 
10
  @pytest.mark.asyncio
11
  async def test_retry_middleware_succeeds_first_try():
tests/unit/middleware/test_token_tracking.py CHANGED
@@ -4,6 +4,8 @@ import pytest
4
 
5
  from src.middleware.token_tracking import TokenTrackingMiddleware
6
 
 
 
7
 
8
  @pytest.mark.asyncio
9
  async def test_token_tracking_middleware_counts_tokens():
 
4
 
5
  from src.middleware.token_tracking import TokenTrackingMiddleware
6
 
7
+ pytestmark = pytest.mark.unit
8
+
9
 
10
  @pytest.mark.asyncio
11
  async def test_token_tracking_middleware_counts_tokens():