JatsTheAIGen commited on
Commit
8f308fb
·
1 Parent(s): edbd656

router fixed v2

Browse files
Files changed (3) hide show
  1. llm_router.py +52 -58
  2. src/llm_router.py +58 -64
  3. test_task_type_fix.py +1 -0
llm_router.py CHANGED
@@ -1,4 +1,4 @@
1
- # llm_router.py
2
  import logging
3
  from models_config import LLM_CONFIG
4
 
@@ -28,6 +28,7 @@ class LLMRouter:
28
  model_config = self._get_fallback_model(task_type)
29
  logger.info(f"Fallback model: {model_config['model_id']}")
30
 
 
31
  result = await self._call_hf_endpoint(model_config, prompt, task_type, **kwargs)
32
  logger.info(f"Inference complete for {task_type}")
33
  return result
@@ -71,8 +72,10 @@ class LLMRouter:
71
 
72
  async def _call_hf_endpoint(self, model_config: dict, prompt: str, task_type: str, **kwargs):
73
  """
74
- Make actual call to Hugging Face Chat Completions API
75
  Uses the correct chat completions protocol
 
 
76
  """
77
  try:
78
  import requests
@@ -88,6 +91,8 @@ class LLMRouter:
88
  logger.info("LLM API REQUEST - COMPLETE PROMPT:")
89
  logger.info("=" * 80)
90
  logger.info(f"Model: {model_id}")
 
 
91
  logger.info(f"Task Type: {task_type}")
92
  logger.info(f"Prompt Length: {len(prompt)} characters")
93
  logger.info("-" * 40)
@@ -98,76 +103,41 @@ class LLMRouter:
98
  logger.info("END OF PROMPT")
99
  logger.info("=" * 80)
100
 
101
- headers = {
102
- "Authorization": f"Bearer {self.hf_token}",
103
- "Content-Type": "application/json"
104
- }
105
-
106
- # Prepare payload in chat completions format
107
- # Extract the actual question from the prompt if it's in a structured format
108
- user_message = prompt if "User Question:" not in prompt else prompt.split("User Question:")[1].split("\n")[0].strip()
109
 
110
  payload = {
111
- "model": f"{model_id}:together", # Use the Together endpoint as specified
112
  "messages": [
113
  {
114
  "role": "user",
115
- "content": user_message
116
  }
117
  ],
118
- "max_tokens": kwargs.get("max_tokens", 2000),
119
- "temperature": kwargs.get("temperature", 0.7),
120
- "top_p": kwargs.get("top_p", 0.95)
121
  }
122
 
123
- # Log complete API request details
124
- logger.info("=" * 80)
125
- logger.info("LLM API REQUEST DETAILS:")
126
- logger.info("=" * 80)
127
- logger.info(f"API URL: {api_url}")
128
- logger.info(f"Model: {model_id}")
129
- logger.info(f"Task Type: {task_type}")
130
- logger.info(f"Max Tokens: {kwargs.get('max_tokens', 2000)}")
131
- logger.info(f"Temperature: {kwargs.get('temperature', 0.7)}")
132
- logger.info(f"Top P: {kwargs.get('top_p', 0.95)}")
133
- logger.info(f"User Message Length: {len(user_message)} characters")
134
- logger.info("-" * 40)
135
- logger.info("API PAYLOAD:")
136
- logger.info("-" * 40)
137
- import json
138
- logger.info(json.dumps(payload, indent=2))
139
- logger.info("-" * 40)
140
- logger.info("END OF API REQUEST")
141
- logger.info("=" * 80)
142
 
143
- # Make the API call
144
- response = requests.post(api_url, json=payload, headers=headers, timeout=60)
 
 
145
 
146
  if response.status_code == 200:
147
  result = response.json()
 
148
 
149
- # Log complete API response metadata
150
- logger.info("=" * 80)
151
- logger.info("LLM API RESPONSE METADATA:")
152
- logger.info("=" * 80)
153
- logger.info(f"Status Code: {response.status_code}")
154
- logger.info(f"Response Headers: {dict(response.headers)}")
155
- logger.info(f"Response Size: {len(response.text)} characters")
156
- logger.info("-" * 40)
157
- logger.info("COMPLETE API RESPONSE JSON:")
158
- logger.info("-" * 40)
159
- logger.info(json.dumps(result, indent=2))
160
- logger.info("-" * 40)
161
- logger.info("END OF API RESPONSE METADATA")
162
- logger.info("=" * 80)
163
-
164
- # Handle chat completions response format
165
- if "choices" in result and len(result["choices"]) > 0:
166
- message = result["choices"][0].get("message", {})
167
- generated_text = message.get("content", "")
168
 
169
- # Ensure we always return a string, never None
170
- if not generated_text or not isinstance(generated_text, str):
171
  logger.warning(f"Empty or invalid response, using fallback")
172
  return None
173
 
@@ -176,6 +146,8 @@ class LLMRouter:
176
  logger.info("COMPLETE LLM API RESPONSE:")
177
  logger.info("=" * 80)
178
  logger.info(f"Model: {model_id}")
 
 
179
  logger.info(f"Task Type: {task_type}")
180
  logger.info(f"Response Length: {len(generated_text)} characters")
181
  logger.info("-" * 40)
@@ -193,6 +165,8 @@ class LLMRouter:
193
  # Model is loading, retry with simpler model
194
  logger.warning(f"Model loading (503), trying fallback")
195
  fallback_config = self._get_fallback_model("response_synthesis")
 
 
196
  return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs)
197
  else:
198
  logger.error(f"HF API error: {response.status_code} - {response.text}")
@@ -204,4 +178,24 @@ class LLMRouter:
204
  except Exception as e:
205
  logger.error(f"Error calling HF endpoint: {e}", exc_info=True)
206
  return None
207
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # llm_router.py - FIXED VERSION
2
  import logging
3
  from models_config import LLM_CONFIG
4
 
 
28
  model_config = self._get_fallback_model(task_type)
29
  logger.info(f"Fallback model: {model_config['model_id']}")
30
 
31
+ # FIXED: Ensure task_type is passed to the _call_hf_endpoint method
32
  result = await self._call_hf_endpoint(model_config, prompt, task_type, **kwargs)
33
  logger.info(f"Inference complete for {task_type}")
34
  return result
 
72
 
73
  async def _call_hf_endpoint(self, model_config: dict, prompt: str, task_type: str, **kwargs):
74
  """
75
+ FIXED: Make actual call to Hugging Face Chat Completions API
76
  Uses the correct chat completions protocol
77
+
78
+ IMPORTANT: task_type parameter is now properly included in the method signature
79
  """
80
  try:
81
  import requests
 
91
  logger.info("LLM API REQUEST - COMPLETE PROMPT:")
92
  logger.info("=" * 80)
93
  logger.info(f"Model: {model_id}")
94
+
95
+ # FIXED: task_type is now properly available as a parameter
96
  logger.info(f"Task Type: {task_type}")
97
  logger.info(f"Prompt Length: {len(prompt)} characters")
98
  logger.info("-" * 40)
 
103
  logger.info("END OF PROMPT")
104
  logger.info("=" * 80)
105
 
106
+ # Prepare the request payload
107
+ max_tokens = kwargs.get('max_tokens', 512)
108
+ temperature = kwargs.get('temperature', 0.7)
 
 
 
 
 
109
 
110
  payload = {
111
+ "model": model_id,
112
  "messages": [
113
  {
114
  "role": "user",
115
+ "content": prompt
116
  }
117
  ],
118
+ "max_tokens": max_tokens,
119
+ "temperature": temperature,
120
+ "stream": False
121
  }
122
 
123
+ headers = {
124
+ "Authorization": f"Bearer {self.hf_token}",
125
+ "Content-Type": "application/json"
126
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
+ logger.info(f"Sending request to: {api_url}")
129
+ logger.debug(f"Payload: {payload}")
130
+
131
+ response = requests.post(api_url, json=payload, headers=headers, timeout=30)
132
 
133
  if response.status_code == 200:
134
  result = response.json()
135
+ logger.debug(f"Raw response: {result}")
136
 
137
+ if 'choices' in result and len(result['choices']) > 0:
138
+ generated_text = result['choices'][0]['message']['content']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
+ if not generated_text or generated_text.strip() == "":
 
141
  logger.warning(f"Empty or invalid response, using fallback")
142
  return None
143
 
 
146
  logger.info("COMPLETE LLM API RESPONSE:")
147
  logger.info("=" * 80)
148
  logger.info(f"Model: {model_id}")
149
+
150
+ # FIXED: task_type is now properly available
151
  logger.info(f"Task Type: {task_type}")
152
  logger.info(f"Response Length: {len(generated_text)} characters")
153
  logger.info("-" * 40)
 
165
  # Model is loading, retry with simpler model
166
  logger.warning(f"Model loading (503), trying fallback")
167
  fallback_config = self._get_fallback_model("response_synthesis")
168
+
169
+ # FIXED: Ensure task_type is passed in recursive call
170
  return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs)
171
  else:
172
  logger.error(f"HF API error: {response.status_code} - {response.text}")
 
178
  except Exception as e:
179
  logger.error(f"Error calling HF endpoint: {e}", exc_info=True)
180
  return None
181
+
182
+ async def get_available_models(self):
183
+ """
184
+ Get list of available models for testing
185
+ """
186
+ return list(LLM_CONFIG["models"].keys())
187
+
188
+ async def health_check(self):
189
+ """
190
+ Perform health check on all models
191
+ """
192
+ health_status = {}
193
+ for model_name, model_config in LLM_CONFIG["models"].items():
194
+ model_id = model_config["model_id"]
195
+ is_healthy = await self._is_model_healthy(model_id)
196
+ health_status[model_name] = {
197
+ "model_id": model_id,
198
+ "healthy": is_healthy
199
+ }
200
+
201
+ return health_status
src/llm_router.py CHANGED
@@ -1,4 +1,4 @@
1
- # llm_router.py
2
  import logging
3
  from .models_config import LLM_CONFIG
4
 
@@ -28,7 +28,8 @@ class LLMRouter:
28
  model_config = self._get_fallback_model(task_type)
29
  logger.info(f"Fallback model: {model_config['model_id']}")
30
 
31
- result = await self._call_hf_endpoint(model_config, prompt, **kwargs)
 
32
  logger.info(f"Inference complete for {task_type}")
33
  return result
34
 
@@ -69,18 +70,19 @@ class LLMRouter:
69
  }
70
  return fallback_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])
71
 
72
- async def _call_hf_endpoint(self, model_config: dict, prompt: str, **kwargs):
73
  """
74
- Make actual call to Hugging Face Chat Completions API
75
  Uses the correct chat completions protocol
 
 
76
  """
77
  try:
78
  import requests
79
 
80
  model_id = model_config["model_id"]
81
- is_chat_model = model_config.get("is_chat_model", True)
82
 
83
- # Use the chat completions endpoint for chat models
84
  api_url = "https://router.huggingface.co/v1/chat/completions"
85
 
86
  logger.info(f"Calling HF Chat Completions API for model: {model_id}")
@@ -89,6 +91,8 @@ class LLMRouter:
89
  logger.info("LLM API REQUEST - COMPLETE PROMPT:")
90
  logger.info("=" * 80)
91
  logger.info(f"Model: {model_id}")
 
 
92
  logger.info(f"Task Type: {task_type}")
93
  logger.info(f"Prompt Length: {len(prompt)} characters")
94
  logger.info("-" * 40)
@@ -99,76 +103,41 @@ class LLMRouter:
99
  logger.info("END OF PROMPT")
100
  logger.info("=" * 80)
101
 
102
- headers = {
103
- "Authorization": f"Bearer {self.hf_token}",
104
- "Content-Type": "application/json"
105
- }
106
-
107
- # Prepare payload in chat completions format
108
- # Extract the actual question from the prompt if it's in a structured format
109
- user_message = prompt if "User Question:" not in prompt else prompt.split("User Question:")[1].split("\n")[0].strip()
110
 
111
  payload = {
112
- "model": f"{model_id}:together", # Use the Together endpoint as specified
113
  "messages": [
114
  {
115
  "role": "user",
116
- "content": user_message
117
  }
118
  ],
119
- "max_tokens": kwargs.get("max_tokens", 2000),
120
- "temperature": kwargs.get("temperature", 0.7),
121
- "top_p": kwargs.get("top_p", 0.95)
122
  }
123
 
124
- # Log complete API request details
125
- logger.info("=" * 80)
126
- logger.info("LLM API REQUEST DETAILS:")
127
- logger.info("=" * 80)
128
- logger.info(f"API URL: {api_url}")
129
- logger.info(f"Model: {model_id}")
130
- logger.info(f"Task Type: {task_type}")
131
- logger.info(f"Max Tokens: {kwargs.get('max_tokens', 2000)}")
132
- logger.info(f"Temperature: {kwargs.get('temperature', 0.7)}")
133
- logger.info(f"Top P: {kwargs.get('top_p', 0.95)}")
134
- logger.info(f"User Message Length: {len(user_message)} characters")
135
- logger.info("-" * 40)
136
- logger.info("API PAYLOAD:")
137
- logger.info("-" * 40)
138
- import json
139
- logger.info(json.dumps(payload, indent=2))
140
- logger.info("-" * 40)
141
- logger.info("END OF API REQUEST")
142
- logger.info("=" * 80)
143
 
144
- # Make the API call
145
- response = requests.post(api_url, json=payload, headers=headers, timeout=60)
 
 
146
 
147
  if response.status_code == 200:
148
  result = response.json()
 
149
 
150
- # Log complete API response metadata
151
- logger.info("=" * 80)
152
- logger.info("LLM API RESPONSE METADATA:")
153
- logger.info("=" * 80)
154
- logger.info(f"Status Code: {response.status_code}")
155
- logger.info(f"Response Headers: {dict(response.headers)}")
156
- logger.info(f"Response Size: {len(response.text)} characters")
157
- logger.info("-" * 40)
158
- logger.info("COMPLETE API RESPONSE JSON:")
159
- logger.info("-" * 40)
160
- logger.info(json.dumps(result, indent=2))
161
- logger.info("-" * 40)
162
- logger.info("END OF API RESPONSE METADATA")
163
- logger.info("=" * 80)
164
-
165
- # Handle chat completions response format
166
- if "choices" in result and len(result["choices"]) > 0:
167
- message = result["choices"][0].get("message", {})
168
- generated_text = message.get("content", "")
169
 
170
- # Ensure we always return a string, never None
171
- if not generated_text or not isinstance(generated_text, str):
172
  logger.warning(f"Empty or invalid response, using fallback")
173
  return None
174
 
@@ -177,6 +146,8 @@ class LLMRouter:
177
  logger.info("COMPLETE LLM API RESPONSE:")
178
  logger.info("=" * 80)
179
  logger.info(f"Model: {model_id}")
 
 
180
  logger.info(f"Task Type: {task_type}")
181
  logger.info(f"Response Length: {len(generated_text)} characters")
182
  logger.info("-" * 40)
@@ -194,14 +165,37 @@ class LLMRouter:
194
  # Model is loading, retry with simpler model
195
  logger.warning(f"Model loading (503), trying fallback")
196
  fallback_config = self._get_fallback_model("response_synthesis")
197
- return await self._call_hf_endpoint(fallback_config, prompt, **kwargs)
 
 
198
  else:
199
  logger.error(f"HF API error: {response.status_code} - {response.text}")
200
  return None
201
 
202
  except ImportError:
203
- logger.warning("requests library not available, API call failed")
204
- return None
205
  except Exception as e:
206
  logger.error(f"Error calling HF endpoint: {e}", exc_info=True)
207
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # llm_router.py - FIXED VERSION
2
  import logging
3
  from .models_config import LLM_CONFIG
4
 
 
28
  model_config = self._get_fallback_model(task_type)
29
  logger.info(f"Fallback model: {model_config['model_id']}")
30
 
31
+ # FIXED: Ensure task_type is passed to the _call_hf_endpoint method
32
+ result = await self._call_hf_endpoint(model_config, prompt, task_type, **kwargs)
33
  logger.info(f"Inference complete for {task_type}")
34
  return result
35
 
 
70
  }
71
  return fallback_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])
72
 
73
+ async def _call_hf_endpoint(self, model_config: dict, prompt: str, task_type: str, **kwargs):
74
  """
75
+ FIXED: Make actual call to Hugging Face Chat Completions API
76
  Uses the correct chat completions protocol
77
+
78
+ IMPORTANT: task_type parameter is now properly included in the method signature
79
  """
80
  try:
81
  import requests
82
 
83
  model_id = model_config["model_id"]
 
84
 
85
+ # Use the chat completions endpoint
86
  api_url = "https://router.huggingface.co/v1/chat/completions"
87
 
88
  logger.info(f"Calling HF Chat Completions API for model: {model_id}")
 
91
  logger.info("LLM API REQUEST - COMPLETE PROMPT:")
92
  logger.info("=" * 80)
93
  logger.info(f"Model: {model_id}")
94
+
95
+ # FIXED: task_type is now properly available as a parameter
96
  logger.info(f"Task Type: {task_type}")
97
  logger.info(f"Prompt Length: {len(prompt)} characters")
98
  logger.info("-" * 40)
 
103
  logger.info("END OF PROMPT")
104
  logger.info("=" * 80)
105
 
106
+ # Prepare the request payload
107
+ max_tokens = kwargs.get('max_tokens', 512)
108
+ temperature = kwargs.get('temperature', 0.7)
 
 
 
 
 
109
 
110
  payload = {
111
+ "model": model_id,
112
  "messages": [
113
  {
114
  "role": "user",
115
+ "content": prompt
116
  }
117
  ],
118
+ "max_tokens": max_tokens,
119
+ "temperature": temperature,
120
+ "stream": False
121
  }
122
 
123
+ headers = {
124
+ "Authorization": f"Bearer {self.hf_token}",
125
+ "Content-Type": "application/json"
126
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
+ logger.info(f"Sending request to: {api_url}")
129
+ logger.debug(f"Payload: {payload}")
130
+
131
+ response = requests.post(api_url, json=payload, headers=headers, timeout=30)
132
 
133
  if response.status_code == 200:
134
  result = response.json()
135
+ logger.debug(f"Raw response: {result}")
136
 
137
+ if 'choices' in result and len(result['choices']) > 0:
138
+ generated_text = result['choices'][0]['message']['content']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
+ if not generated_text or generated_text.strip() == "":
 
141
  logger.warning(f"Empty or invalid response, using fallback")
142
  return None
143
 
 
146
  logger.info("COMPLETE LLM API RESPONSE:")
147
  logger.info("=" * 80)
148
  logger.info(f"Model: {model_id}")
149
+
150
+ # FIXED: task_type is now properly available
151
  logger.info(f"Task Type: {task_type}")
152
  logger.info(f"Response Length: {len(generated_text)} characters")
153
  logger.info("-" * 40)
 
165
  # Model is loading, retry with simpler model
166
  logger.warning(f"Model loading (503), trying fallback")
167
  fallback_config = self._get_fallback_model("response_synthesis")
168
+
169
+ # FIXED: Ensure task_type is passed in recursive call
170
+ return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs)
171
  else:
172
  logger.error(f"HF API error: {response.status_code} - {response.text}")
173
  return None
174
 
175
  except ImportError:
176
+ logger.warning("requests library not available, using mock response")
177
+ return f"[Mock] Response to: {prompt[:100]}..."
178
  except Exception as e:
179
  logger.error(f"Error calling HF endpoint: {e}", exc_info=True)
180
  return None
181
+
182
+ async def get_available_models(self):
183
+ """
184
+ Get list of available models for testing
185
+ """
186
+ return list(LLM_CONFIG["models"].keys())
187
+
188
+ async def health_check(self):
189
+ """
190
+ Perform health check on all models
191
+ """
192
+ health_status = {}
193
+ for model_name, model_config in LLM_CONFIG["models"].items():
194
+ model_id = model_config["model_id"]
195
+ is_healthy = await self._is_model_healthy(model_id)
196
+ health_status[model_name] = {
197
+ "model_id": model_id,
198
+ "healthy": is_healthy
199
+ }
200
+
201
+ return health_status
test_task_type_fix.py CHANGED
@@ -153,3 +153,4 @@ if __name__ == "__main__":
153
  print("The method signature is not correct.")
154
 
155
  print("\nCheck 'test_task_type_fix.log' file for detailed logs.")
 
 
153
  print("The method signature is not correct.")
154
 
155
  print("\nCheck 'test_task_type_fix.log' file for detailed logs.")
156
+