Commit
·
8f308fb
1
Parent(s):
edbd656
router fixed v2
Browse files- llm_router.py +52 -58
- src/llm_router.py +58 -64
- test_task_type_fix.py +1 -0
llm_router.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# llm_router.py
|
| 2 |
import logging
|
| 3 |
from models_config import LLM_CONFIG
|
| 4 |
|
|
@@ -28,6 +28,7 @@ class LLMRouter:
|
|
| 28 |
model_config = self._get_fallback_model(task_type)
|
| 29 |
logger.info(f"Fallback model: {model_config['model_id']}")
|
| 30 |
|
|
|
|
| 31 |
result = await self._call_hf_endpoint(model_config, prompt, task_type, **kwargs)
|
| 32 |
logger.info(f"Inference complete for {task_type}")
|
| 33 |
return result
|
|
@@ -71,8 +72,10 @@ class LLMRouter:
|
|
| 71 |
|
| 72 |
async def _call_hf_endpoint(self, model_config: dict, prompt: str, task_type: str, **kwargs):
|
| 73 |
"""
|
| 74 |
-
Make actual call to Hugging Face Chat Completions API
|
| 75 |
Uses the correct chat completions protocol
|
|
|
|
|
|
|
| 76 |
"""
|
| 77 |
try:
|
| 78 |
import requests
|
|
@@ -88,6 +91,8 @@ class LLMRouter:
|
|
| 88 |
logger.info("LLM API REQUEST - COMPLETE PROMPT:")
|
| 89 |
logger.info("=" * 80)
|
| 90 |
logger.info(f"Model: {model_id}")
|
|
|
|
|
|
|
| 91 |
logger.info(f"Task Type: {task_type}")
|
| 92 |
logger.info(f"Prompt Length: {len(prompt)} characters")
|
| 93 |
logger.info("-" * 40)
|
|
@@ -98,76 +103,41 @@ class LLMRouter:
|
|
| 98 |
logger.info("END OF PROMPT")
|
| 99 |
logger.info("=" * 80)
|
| 100 |
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
}
|
| 105 |
-
|
| 106 |
-
# Prepare payload in chat completions format
|
| 107 |
-
# Extract the actual question from the prompt if it's in a structured format
|
| 108 |
-
user_message = prompt if "User Question:" not in prompt else prompt.split("User Question:")[1].split("\n")[0].strip()
|
| 109 |
|
| 110 |
payload = {
|
| 111 |
-
"model":
|
| 112 |
"messages": [
|
| 113 |
{
|
| 114 |
"role": "user",
|
| 115 |
-
"content":
|
| 116 |
}
|
| 117 |
],
|
| 118 |
-
"max_tokens":
|
| 119 |
-
"temperature":
|
| 120 |
-
"
|
| 121 |
}
|
| 122 |
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
logger.info(f"API URL: {api_url}")
|
| 128 |
-
logger.info(f"Model: {model_id}")
|
| 129 |
-
logger.info(f"Task Type: {task_type}")
|
| 130 |
-
logger.info(f"Max Tokens: {kwargs.get('max_tokens', 2000)}")
|
| 131 |
-
logger.info(f"Temperature: {kwargs.get('temperature', 0.7)}")
|
| 132 |
-
logger.info(f"Top P: {kwargs.get('top_p', 0.95)}")
|
| 133 |
-
logger.info(f"User Message Length: {len(user_message)} characters")
|
| 134 |
-
logger.info("-" * 40)
|
| 135 |
-
logger.info("API PAYLOAD:")
|
| 136 |
-
logger.info("-" * 40)
|
| 137 |
-
import json
|
| 138 |
-
logger.info(json.dumps(payload, indent=2))
|
| 139 |
-
logger.info("-" * 40)
|
| 140 |
-
logger.info("END OF API REQUEST")
|
| 141 |
-
logger.info("=" * 80)
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
|
|
|
|
|
|
| 145 |
|
| 146 |
if response.status_code == 200:
|
| 147 |
result = response.json()
|
|
|
|
| 148 |
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
logger.info("LLM API RESPONSE METADATA:")
|
| 152 |
-
logger.info("=" * 80)
|
| 153 |
-
logger.info(f"Status Code: {response.status_code}")
|
| 154 |
-
logger.info(f"Response Headers: {dict(response.headers)}")
|
| 155 |
-
logger.info(f"Response Size: {len(response.text)} characters")
|
| 156 |
-
logger.info("-" * 40)
|
| 157 |
-
logger.info("COMPLETE API RESPONSE JSON:")
|
| 158 |
-
logger.info("-" * 40)
|
| 159 |
-
logger.info(json.dumps(result, indent=2))
|
| 160 |
-
logger.info("-" * 40)
|
| 161 |
-
logger.info("END OF API RESPONSE METADATA")
|
| 162 |
-
logger.info("=" * 80)
|
| 163 |
-
|
| 164 |
-
# Handle chat completions response format
|
| 165 |
-
if "choices" in result and len(result["choices"]) > 0:
|
| 166 |
-
message = result["choices"][0].get("message", {})
|
| 167 |
-
generated_text = message.get("content", "")
|
| 168 |
|
| 169 |
-
|
| 170 |
-
if not generated_text or not isinstance(generated_text, str):
|
| 171 |
logger.warning(f"Empty or invalid response, using fallback")
|
| 172 |
return None
|
| 173 |
|
|
@@ -176,6 +146,8 @@ class LLMRouter:
|
|
| 176 |
logger.info("COMPLETE LLM API RESPONSE:")
|
| 177 |
logger.info("=" * 80)
|
| 178 |
logger.info(f"Model: {model_id}")
|
|
|
|
|
|
|
| 179 |
logger.info(f"Task Type: {task_type}")
|
| 180 |
logger.info(f"Response Length: {len(generated_text)} characters")
|
| 181 |
logger.info("-" * 40)
|
|
@@ -193,6 +165,8 @@ class LLMRouter:
|
|
| 193 |
# Model is loading, retry with simpler model
|
| 194 |
logger.warning(f"Model loading (503), trying fallback")
|
| 195 |
fallback_config = self._get_fallback_model("response_synthesis")
|
|
|
|
|
|
|
| 196 |
return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs)
|
| 197 |
else:
|
| 198 |
logger.error(f"HF API error: {response.status_code} - {response.text}")
|
|
@@ -204,4 +178,24 @@ class LLMRouter:
|
|
| 204 |
except Exception as e:
|
| 205 |
logger.error(f"Error calling HF endpoint: {e}", exc_info=True)
|
| 206 |
return None
|
| 207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# llm_router.py - FIXED VERSION
|
| 2 |
import logging
|
| 3 |
from models_config import LLM_CONFIG
|
| 4 |
|
|
|
|
| 28 |
model_config = self._get_fallback_model(task_type)
|
| 29 |
logger.info(f"Fallback model: {model_config['model_id']}")
|
| 30 |
|
| 31 |
+
# FIXED: Ensure task_type is passed to the _call_hf_endpoint method
|
| 32 |
result = await self._call_hf_endpoint(model_config, prompt, task_type, **kwargs)
|
| 33 |
logger.info(f"Inference complete for {task_type}")
|
| 34 |
return result
|
|
|
|
| 72 |
|
| 73 |
async def _call_hf_endpoint(self, model_config: dict, prompt: str, task_type: str, **kwargs):
|
| 74 |
"""
|
| 75 |
+
FIXED: Make actual call to Hugging Face Chat Completions API
|
| 76 |
Uses the correct chat completions protocol
|
| 77 |
+
|
| 78 |
+
IMPORTANT: task_type parameter is now properly included in the method signature
|
| 79 |
"""
|
| 80 |
try:
|
| 81 |
import requests
|
|
|
|
| 91 |
logger.info("LLM API REQUEST - COMPLETE PROMPT:")
|
| 92 |
logger.info("=" * 80)
|
| 93 |
logger.info(f"Model: {model_id}")
|
| 94 |
+
|
| 95 |
+
# FIXED: task_type is now properly available as a parameter
|
| 96 |
logger.info(f"Task Type: {task_type}")
|
| 97 |
logger.info(f"Prompt Length: {len(prompt)} characters")
|
| 98 |
logger.info("-" * 40)
|
|
|
|
| 103 |
logger.info("END OF PROMPT")
|
| 104 |
logger.info("=" * 80)
|
| 105 |
|
| 106 |
+
# Prepare the request payload
|
| 107 |
+
max_tokens = kwargs.get('max_tokens', 512)
|
| 108 |
+
temperature = kwargs.get('temperature', 0.7)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
payload = {
|
| 111 |
+
"model": model_id,
|
| 112 |
"messages": [
|
| 113 |
{
|
| 114 |
"role": "user",
|
| 115 |
+
"content": prompt
|
| 116 |
}
|
| 117 |
],
|
| 118 |
+
"max_tokens": max_tokens,
|
| 119 |
+
"temperature": temperature,
|
| 120 |
+
"stream": False
|
| 121 |
}
|
| 122 |
|
| 123 |
+
headers = {
|
| 124 |
+
"Authorization": f"Bearer {self.hf_token}",
|
| 125 |
+
"Content-Type": "application/json"
|
| 126 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
+
logger.info(f"Sending request to: {api_url}")
|
| 129 |
+
logger.debug(f"Payload: {payload}")
|
| 130 |
+
|
| 131 |
+
response = requests.post(api_url, json=payload, headers=headers, timeout=30)
|
| 132 |
|
| 133 |
if response.status_code == 200:
|
| 134 |
result = response.json()
|
| 135 |
+
logger.debug(f"Raw response: {result}")
|
| 136 |
|
| 137 |
+
if 'choices' in result and len(result['choices']) > 0:
|
| 138 |
+
generated_text = result['choices'][0]['message']['content']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
+
if not generated_text or generated_text.strip() == "":
|
|
|
|
| 141 |
logger.warning(f"Empty or invalid response, using fallback")
|
| 142 |
return None
|
| 143 |
|
|
|
|
| 146 |
logger.info("COMPLETE LLM API RESPONSE:")
|
| 147 |
logger.info("=" * 80)
|
| 148 |
logger.info(f"Model: {model_id}")
|
| 149 |
+
|
| 150 |
+
# FIXED: task_type is now properly available
|
| 151 |
logger.info(f"Task Type: {task_type}")
|
| 152 |
logger.info(f"Response Length: {len(generated_text)} characters")
|
| 153 |
logger.info("-" * 40)
|
|
|
|
| 165 |
# Model is loading, retry with simpler model
|
| 166 |
logger.warning(f"Model loading (503), trying fallback")
|
| 167 |
fallback_config = self._get_fallback_model("response_synthesis")
|
| 168 |
+
|
| 169 |
+
# FIXED: Ensure task_type is passed in recursive call
|
| 170 |
return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs)
|
| 171 |
else:
|
| 172 |
logger.error(f"HF API error: {response.status_code} - {response.text}")
|
|
|
|
| 178 |
except Exception as e:
|
| 179 |
logger.error(f"Error calling HF endpoint: {e}", exc_info=True)
|
| 180 |
return None
|
| 181 |
+
|
| 182 |
+
async def get_available_models(self):
|
| 183 |
+
"""
|
| 184 |
+
Get list of available models for testing
|
| 185 |
+
"""
|
| 186 |
+
return list(LLM_CONFIG["models"].keys())
|
| 187 |
+
|
| 188 |
+
async def health_check(self):
|
| 189 |
+
"""
|
| 190 |
+
Perform health check on all models
|
| 191 |
+
"""
|
| 192 |
+
health_status = {}
|
| 193 |
+
for model_name, model_config in LLM_CONFIG["models"].items():
|
| 194 |
+
model_id = model_config["model_id"]
|
| 195 |
+
is_healthy = await self._is_model_healthy(model_id)
|
| 196 |
+
health_status[model_name] = {
|
| 197 |
+
"model_id": model_id,
|
| 198 |
+
"healthy": is_healthy
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
return health_status
|
src/llm_router.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
# llm_router.py
|
| 2 |
import logging
|
| 3 |
from .models_config import LLM_CONFIG
|
| 4 |
|
|
@@ -28,7 +28,8 @@ class LLMRouter:
|
|
| 28 |
model_config = self._get_fallback_model(task_type)
|
| 29 |
logger.info(f"Fallback model: {model_config['model_id']}")
|
| 30 |
|
| 31 |
-
|
|
|
|
| 32 |
logger.info(f"Inference complete for {task_type}")
|
| 33 |
return result
|
| 34 |
|
|
@@ -69,18 +70,19 @@ class LLMRouter:
|
|
| 69 |
}
|
| 70 |
return fallback_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])
|
| 71 |
|
| 72 |
-
async def _call_hf_endpoint(self, model_config: dict, prompt: str, **kwargs):
|
| 73 |
"""
|
| 74 |
-
Make actual call to Hugging Face Chat Completions API
|
| 75 |
Uses the correct chat completions protocol
|
|
|
|
|
|
|
| 76 |
"""
|
| 77 |
try:
|
| 78 |
import requests
|
| 79 |
|
| 80 |
model_id = model_config["model_id"]
|
| 81 |
-
is_chat_model = model_config.get("is_chat_model", True)
|
| 82 |
|
| 83 |
-
# Use the chat completions endpoint
|
| 84 |
api_url = "https://router.huggingface.co/v1/chat/completions"
|
| 85 |
|
| 86 |
logger.info(f"Calling HF Chat Completions API for model: {model_id}")
|
|
@@ -89,6 +91,8 @@ class LLMRouter:
|
|
| 89 |
logger.info("LLM API REQUEST - COMPLETE PROMPT:")
|
| 90 |
logger.info("=" * 80)
|
| 91 |
logger.info(f"Model: {model_id}")
|
|
|
|
|
|
|
| 92 |
logger.info(f"Task Type: {task_type}")
|
| 93 |
logger.info(f"Prompt Length: {len(prompt)} characters")
|
| 94 |
logger.info("-" * 40)
|
|
@@ -99,76 +103,41 @@ class LLMRouter:
|
|
| 99 |
logger.info("END OF PROMPT")
|
| 100 |
logger.info("=" * 80)
|
| 101 |
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
}
|
| 106 |
-
|
| 107 |
-
# Prepare payload in chat completions format
|
| 108 |
-
# Extract the actual question from the prompt if it's in a structured format
|
| 109 |
-
user_message = prompt if "User Question:" not in prompt else prompt.split("User Question:")[1].split("\n")[0].strip()
|
| 110 |
|
| 111 |
payload = {
|
| 112 |
-
"model":
|
| 113 |
"messages": [
|
| 114 |
{
|
| 115 |
"role": "user",
|
| 116 |
-
"content":
|
| 117 |
}
|
| 118 |
],
|
| 119 |
-
"max_tokens":
|
| 120 |
-
"temperature":
|
| 121 |
-
"
|
| 122 |
}
|
| 123 |
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
logger.info(f"API URL: {api_url}")
|
| 129 |
-
logger.info(f"Model: {model_id}")
|
| 130 |
-
logger.info(f"Task Type: {task_type}")
|
| 131 |
-
logger.info(f"Max Tokens: {kwargs.get('max_tokens', 2000)}")
|
| 132 |
-
logger.info(f"Temperature: {kwargs.get('temperature', 0.7)}")
|
| 133 |
-
logger.info(f"Top P: {kwargs.get('top_p', 0.95)}")
|
| 134 |
-
logger.info(f"User Message Length: {len(user_message)} characters")
|
| 135 |
-
logger.info("-" * 40)
|
| 136 |
-
logger.info("API PAYLOAD:")
|
| 137 |
-
logger.info("-" * 40)
|
| 138 |
-
import json
|
| 139 |
-
logger.info(json.dumps(payload, indent=2))
|
| 140 |
-
logger.info("-" * 40)
|
| 141 |
-
logger.info("END OF API REQUEST")
|
| 142 |
-
logger.info("=" * 80)
|
| 143 |
|
| 144 |
-
|
| 145 |
-
|
|
|
|
|
|
|
| 146 |
|
| 147 |
if response.status_code == 200:
|
| 148 |
result = response.json()
|
|
|
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
logger.info("LLM API RESPONSE METADATA:")
|
| 153 |
-
logger.info("=" * 80)
|
| 154 |
-
logger.info(f"Status Code: {response.status_code}")
|
| 155 |
-
logger.info(f"Response Headers: {dict(response.headers)}")
|
| 156 |
-
logger.info(f"Response Size: {len(response.text)} characters")
|
| 157 |
-
logger.info("-" * 40)
|
| 158 |
-
logger.info("COMPLETE API RESPONSE JSON:")
|
| 159 |
-
logger.info("-" * 40)
|
| 160 |
-
logger.info(json.dumps(result, indent=2))
|
| 161 |
-
logger.info("-" * 40)
|
| 162 |
-
logger.info("END OF API RESPONSE METADATA")
|
| 163 |
-
logger.info("=" * 80)
|
| 164 |
-
|
| 165 |
-
# Handle chat completions response format
|
| 166 |
-
if "choices" in result and len(result["choices"]) > 0:
|
| 167 |
-
message = result["choices"][0].get("message", {})
|
| 168 |
-
generated_text = message.get("content", "")
|
| 169 |
|
| 170 |
-
|
| 171 |
-
if not generated_text or not isinstance(generated_text, str):
|
| 172 |
logger.warning(f"Empty or invalid response, using fallback")
|
| 173 |
return None
|
| 174 |
|
|
@@ -177,6 +146,8 @@ class LLMRouter:
|
|
| 177 |
logger.info("COMPLETE LLM API RESPONSE:")
|
| 178 |
logger.info("=" * 80)
|
| 179 |
logger.info(f"Model: {model_id}")
|
|
|
|
|
|
|
| 180 |
logger.info(f"Task Type: {task_type}")
|
| 181 |
logger.info(f"Response Length: {len(generated_text)} characters")
|
| 182 |
logger.info("-" * 40)
|
|
@@ -194,14 +165,37 @@ class LLMRouter:
|
|
| 194 |
# Model is loading, retry with simpler model
|
| 195 |
logger.warning(f"Model loading (503), trying fallback")
|
| 196 |
fallback_config = self._get_fallback_model("response_synthesis")
|
| 197 |
-
|
|
|
|
|
|
|
| 198 |
else:
|
| 199 |
logger.error(f"HF API error: {response.status_code} - {response.text}")
|
| 200 |
return None
|
| 201 |
|
| 202 |
except ImportError:
|
| 203 |
-
logger.warning("requests library not available,
|
| 204 |
-
return
|
| 205 |
except Exception as e:
|
| 206 |
logger.error(f"Error calling HF endpoint: {e}", exc_info=True)
|
| 207 |
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# llm_router.py - FIXED VERSION
|
| 2 |
import logging
|
| 3 |
from .models_config import LLM_CONFIG
|
| 4 |
|
|
|
|
| 28 |
model_config = self._get_fallback_model(task_type)
|
| 29 |
logger.info(f"Fallback model: {model_config['model_id']}")
|
| 30 |
|
| 31 |
+
# FIXED: Ensure task_type is passed to the _call_hf_endpoint method
|
| 32 |
+
result = await self._call_hf_endpoint(model_config, prompt, task_type, **kwargs)
|
| 33 |
logger.info(f"Inference complete for {task_type}")
|
| 34 |
return result
|
| 35 |
|
|
|
|
| 70 |
}
|
| 71 |
return fallback_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])
|
| 72 |
|
| 73 |
+
async def _call_hf_endpoint(self, model_config: dict, prompt: str, task_type: str, **kwargs):
|
| 74 |
"""
|
| 75 |
+
FIXED: Make actual call to Hugging Face Chat Completions API
|
| 76 |
Uses the correct chat completions protocol
|
| 77 |
+
|
| 78 |
+
IMPORTANT: task_type parameter is now properly included in the method signature
|
| 79 |
"""
|
| 80 |
try:
|
| 81 |
import requests
|
| 82 |
|
| 83 |
model_id = model_config["model_id"]
|
|
|
|
| 84 |
|
| 85 |
+
# Use the chat completions endpoint
|
| 86 |
api_url = "https://router.huggingface.co/v1/chat/completions"
|
| 87 |
|
| 88 |
logger.info(f"Calling HF Chat Completions API for model: {model_id}")
|
|
|
|
| 91 |
logger.info("LLM API REQUEST - COMPLETE PROMPT:")
|
| 92 |
logger.info("=" * 80)
|
| 93 |
logger.info(f"Model: {model_id}")
|
| 94 |
+
|
| 95 |
+
# FIXED: task_type is now properly available as a parameter
|
| 96 |
logger.info(f"Task Type: {task_type}")
|
| 97 |
logger.info(f"Prompt Length: {len(prompt)} characters")
|
| 98 |
logger.info("-" * 40)
|
|
|
|
| 103 |
logger.info("END OF PROMPT")
|
| 104 |
logger.info("=" * 80)
|
| 105 |
|
| 106 |
+
# Prepare the request payload
|
| 107 |
+
max_tokens = kwargs.get('max_tokens', 512)
|
| 108 |
+
temperature = kwargs.get('temperature', 0.7)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
payload = {
|
| 111 |
+
"model": model_id,
|
| 112 |
"messages": [
|
| 113 |
{
|
| 114 |
"role": "user",
|
| 115 |
+
"content": prompt
|
| 116 |
}
|
| 117 |
],
|
| 118 |
+
"max_tokens": max_tokens,
|
| 119 |
+
"temperature": temperature,
|
| 120 |
+
"stream": False
|
| 121 |
}
|
| 122 |
|
| 123 |
+
headers = {
|
| 124 |
+
"Authorization": f"Bearer {self.hf_token}",
|
| 125 |
+
"Content-Type": "application/json"
|
| 126 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
+
logger.info(f"Sending request to: {api_url}")
|
| 129 |
+
logger.debug(f"Payload: {payload}")
|
| 130 |
+
|
| 131 |
+
response = requests.post(api_url, json=payload, headers=headers, timeout=30)
|
| 132 |
|
| 133 |
if response.status_code == 200:
|
| 134 |
result = response.json()
|
| 135 |
+
logger.debug(f"Raw response: {result}")
|
| 136 |
|
| 137 |
+
if 'choices' in result and len(result['choices']) > 0:
|
| 138 |
+
generated_text = result['choices'][0]['message']['content']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
|
| 140 |
+
if not generated_text or generated_text.strip() == "":
|
|
|
|
| 141 |
logger.warning(f"Empty or invalid response, using fallback")
|
| 142 |
return None
|
| 143 |
|
|
|
|
| 146 |
logger.info("COMPLETE LLM API RESPONSE:")
|
| 147 |
logger.info("=" * 80)
|
| 148 |
logger.info(f"Model: {model_id}")
|
| 149 |
+
|
| 150 |
+
# FIXED: task_type is now properly available
|
| 151 |
logger.info(f"Task Type: {task_type}")
|
| 152 |
logger.info(f"Response Length: {len(generated_text)} characters")
|
| 153 |
logger.info("-" * 40)
|
|
|
|
| 165 |
# Model is loading, retry with simpler model
|
| 166 |
logger.warning(f"Model loading (503), trying fallback")
|
| 167 |
fallback_config = self._get_fallback_model("response_synthesis")
|
| 168 |
+
|
| 169 |
+
# FIXED: Ensure task_type is passed in recursive call
|
| 170 |
+
return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs)
|
| 171 |
else:
|
| 172 |
logger.error(f"HF API error: {response.status_code} - {response.text}")
|
| 173 |
return None
|
| 174 |
|
| 175 |
except ImportError:
|
| 176 |
+
logger.warning("requests library not available, using mock response")
|
| 177 |
+
return f"[Mock] Response to: {prompt[:100]}..."
|
| 178 |
except Exception as e:
|
| 179 |
logger.error(f"Error calling HF endpoint: {e}", exc_info=True)
|
| 180 |
return None
|
| 181 |
+
|
| 182 |
+
async def get_available_models(self):
|
| 183 |
+
"""
|
| 184 |
+
Get list of available models for testing
|
| 185 |
+
"""
|
| 186 |
+
return list(LLM_CONFIG["models"].keys())
|
| 187 |
+
|
| 188 |
+
async def health_check(self):
|
| 189 |
+
"""
|
| 190 |
+
Perform health check on all models
|
| 191 |
+
"""
|
| 192 |
+
health_status = {}
|
| 193 |
+
for model_name, model_config in LLM_CONFIG["models"].items():
|
| 194 |
+
model_id = model_config["model_id"]
|
| 195 |
+
is_healthy = await self._is_model_healthy(model_id)
|
| 196 |
+
health_status[model_name] = {
|
| 197 |
+
"model_id": model_id,
|
| 198 |
+
"healthy": is_healthy
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
return health_status
|
test_task_type_fix.py
CHANGED
|
@@ -153,3 +153,4 @@ if __name__ == "__main__":
|
|
| 153 |
print("The method signature is not correct.")
|
| 154 |
|
| 155 |
print("\nCheck 'test_task_type_fix.log' file for detailed logs.")
|
|
|
|
|
|
| 153 |
print("The method signature is not correct.")
|
| 154 |
|
| 155 |
print("\nCheck 'test_task_type_fix.log' file for detailed logs.")
|
| 156 |
+
|