JatsTheAIGen commited on
Commit
edbd656
·
1 Parent(s): e440f24

router fixed v1

Browse files
Files changed (2) hide show
  1. llm_router.py +3 -3
  2. test_task_type_fix.py +155 -0
llm_router.py CHANGED
@@ -28,7 +28,7 @@ class LLMRouter:
28
  model_config = self._get_fallback_model(task_type)
29
  logger.info(f"Fallback model: {model_config['model_id']}")
30
 
31
- result = await self._call_hf_endpoint(model_config, prompt, **kwargs)
32
  logger.info(f"Inference complete for {task_type}")
33
  return result
34
 
@@ -69,7 +69,7 @@ class LLMRouter:
69
  }
70
  return fallback_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])
71
 
72
- async def _call_hf_endpoint(self, model_config: dict, prompt: str, **kwargs):
73
  """
74
  Make actual call to Hugging Face Chat Completions API
75
  Uses the correct chat completions protocol
@@ -193,7 +193,7 @@ class LLMRouter:
193
  # Model is loading, retry with simpler model
194
  logger.warning(f"Model loading (503), trying fallback")
195
  fallback_config = self._get_fallback_model("response_synthesis")
196
- return await self._call_hf_endpoint(fallback_config, prompt, **kwargs)
197
  else:
198
  logger.error(f"HF API error: {response.status_code} - {response.text}")
199
  return None
 
28
  model_config = self._get_fallback_model(task_type)
29
  logger.info(f"Fallback model: {model_config['model_id']}")
30
 
31
+ result = await self._call_hf_endpoint(model_config, prompt, task_type, **kwargs)
32
  logger.info(f"Inference complete for {task_type}")
33
  return result
34
 
 
69
  }
70
  return fallback_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])
71
 
72
+ async def _call_hf_endpoint(self, model_config: dict, prompt: str, task_type: str, **kwargs):
73
  """
74
  Make actual call to Hugging Face Chat Completions API
75
  Uses the correct chat completions protocol
 
193
  # Model is loading, retry with simpler model
194
  logger.warning(f"Model loading (503), trying fallback")
195
  fallback_config = self._get_fallback_model("response_synthesis")
196
+ return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs)
197
  else:
198
  logger.error(f"HF API error: {response.status_code} - {response.text}")
199
  return None
test_task_type_fix.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script to verify the task_type fix in LLM router
4
+ This script tests that the NameError is resolved and the application works correctly.
5
+ """
6
+
7
+ import logging
8
+ import asyncio
9
+ import sys
10
+ import os
11
+
12
+ # Add the src directory to the path
13
+ sys.path.insert(0, '.')
14
+ sys.path.insert(0, 'src')
15
+
16
+ # Configure logging
17
+ logging.basicConfig(
18
+ level=logging.INFO,
19
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
20
+ handlers=[
21
+ logging.StreamHandler(),
22
+ logging.FileHandler('test_task_type_fix.log')
23
+ ]
24
+ )
25
+
26
+ logger = logging.getLogger(__name__)
27
+
28
+ async def test_task_type_fix():
29
+ """Test that the task_type fix works correctly"""
30
+
31
+ logger.info("=" * 80)
32
+ logger.info("TESTING TASK_TYPE FIX IN LLM ROUTER")
33
+ logger.info("=" * 80)
34
+
35
+ try:
36
+ # Import the LLM router
37
+ from src.llm_router import LLMRouter
38
+ from src.config import settings
39
+
40
+ logger.info("✓ Successfully imported LLM router and config")
41
+
42
+ # Initialize LLM router with a test token
43
+ test_token = os.getenv('HF_TOKEN', 'test_token')
44
+ llm_router = LLMRouter(test_token)
45
+
46
+ logger.info("✓ LLM router initialized")
47
+
48
+ # Test different task types to ensure no NameError occurs
49
+ test_cases = [
50
+ ("intent_classification", "What is the user's intent?"),
51
+ ("response_synthesis", "Generate a response to: Hello"),
52
+ ("safety_check", "Is this content safe?"),
53
+ ("general_reasoning", "Solve this problem: 2+2"),
54
+ ]
55
+
56
+ for task_type, prompt in test_cases:
57
+ logger.info(f"Testing task_type: {task_type}")
58
+
59
+ try:
60
+ # This should not raise a NameError anymore
61
+ result = await llm_router.route_inference(
62
+ task_type=task_type,
63
+ prompt=prompt,
64
+ max_tokens=50,
65
+ temperature=0.7
66
+ )
67
+
68
+ logger.info(f"✓ Task '{task_type}' completed successfully")
69
+ if result:
70
+ logger.info(f" Result length: {len(result)} characters")
71
+ else:
72
+ logger.info(" Result: None (fallback used)")
73
+
74
+ except NameError as e:
75
+ if "task_type" in str(e):
76
+ logger.error(f"✗ NameError still exists for task '{task_type}': {e}")
77
+ return False
78
+ else:
79
+ logger.warning(f"Other NameError for task '{task_type}': {e}")
80
+ except Exception as e:
81
+ logger.info(f"Expected error for task '{task_type}': {e}")
82
+ # This is expected since we don't have real API access
83
+
84
+ logger.info("=" * 80)
85
+ logger.info("TASK_TYPE FIX VERIFICATION COMPLETED")
86
+ logger.info("=" * 80)
87
+ logger.info("✓ No NameError for 'task_type' variable")
88
+ logger.info("✓ All task types processed without variable errors")
89
+ logger.info("✓ LLM router method signature updated correctly")
90
+ logger.info("=" * 80)
91
+
92
+ return True
93
+
94
+ except ImportError as e:
95
+ logger.error(f"Import error: {e}")
96
+ logger.info("This is expected if dependencies are not available")
97
+ return False
98
+ except Exception as e:
99
+ logger.error(f"Test error: {e}")
100
+ logger.info("This indicates a problem with the fix")
101
+ return False
102
+
103
+ def test_method_signature():
104
+ """Test that the method signature is correct"""
105
+
106
+ logger.info("Testing method signature...")
107
+
108
+ try:
109
+ from src.llm_router import LLMRouter
110
+
111
+ # Check if the method signature is correct
112
+ import inspect
113
+
114
+ # Get the method signature
115
+ method = getattr(LLMRouter, '_call_hf_endpoint')
116
+ signature = inspect.signature(method)
117
+
118
+ # Check if task_type parameter exists
119
+ params = list(signature.parameters.keys())
120
+
121
+ logger.info(f"Method parameters: {params}")
122
+
123
+ if 'task_type' in params:
124
+ logger.info("✓ task_type parameter found in method signature")
125
+ return True
126
+ else:
127
+ logger.error("✗ task_type parameter missing from method signature")
128
+ return False
129
+
130
+ except Exception as e:
131
+ logger.error(f"Method signature test failed: {e}")
132
+ return False
133
+
134
+ if __name__ == "__main__":
135
+ print("Testing Task Type Fix in LLM Router")
136
+ print("=" * 50)
137
+
138
+ # Test method signature first
139
+ signature_ok = test_method_signature()
140
+
141
+ if signature_ok:
142
+ # Test the actual functionality
143
+ result = asyncio.run(test_task_type_fix())
144
+
145
+ if result:
146
+ print("\n✅ TASK_TYPE FIX VERIFICATION PASSED")
147
+ print("The NameError issue has been resolved.")
148
+ else:
149
+ print("\n❌ TASK_TYPE FIX VERIFICATION FAILED")
150
+ print("There may still be issues with the fix.")
151
+ else:
152
+ print("\n❌ METHOD SIGNATURE TEST FAILED")
153
+ print("The method signature is not correct.")
154
+
155
+ print("\nCheck 'test_task_type_fix.log' file for detailed logs.")