Update code
Browse files- model_client.py +32 -12
model_client.py
CHANGED
|
@@ -12,6 +12,7 @@ class ModelClient:
|
|
| 12 |
self.timeout = settings.MODEL_TIMEOUT_SECONDS
|
| 13 |
self.temperature = settings.DEFAULT_TEMPERATURE
|
| 14 |
self.top_p = settings.DEFAULT_TOP_P
|
|
|
|
| 15 |
self.hf_token = getattr(settings, "HUGGINGFACE_API_TOKEN", "")
|
| 16 |
|
| 17 |
def _create_client(self) -> InferenceClient:
|
|
@@ -20,7 +21,31 @@ class ModelClient:
|
|
| 20 |
timeout=self.timeout,
|
| 21 |
)
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
def _call_model(self, prompt: str, model_name: str) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
client = self._create_client()
|
| 25 |
|
| 26 |
response = client.chat.completions.create(
|
|
@@ -28,35 +53,30 @@ class ModelClient:
|
|
| 28 |
messages=[
|
| 29 |
{
|
| 30 |
"role": "user",
|
| 31 |
-
"content":
|
| 32 |
}
|
| 33 |
],
|
| 34 |
temperature=self.temperature,
|
| 35 |
top_p=self.top_p,
|
| 36 |
-
max_tokens=
|
| 37 |
)
|
| 38 |
|
| 39 |
-
|
| 40 |
-
raise RuntimeError("Empty response from model.")
|
| 41 |
-
|
| 42 |
-
message = response.choices[0].message
|
| 43 |
-
if not message or not message.content:
|
| 44 |
-
raise RuntimeError("Model returned no content.")
|
| 45 |
-
|
| 46 |
-
return str(message.content).strip()
|
| 47 |
|
| 48 |
def generate(self, prompt: str) -> Tuple[str, str, bool]:
|
| 49 |
try:
|
| 50 |
output = self._call_model(prompt, self.primary_model)
|
| 51 |
return output, self.primary_model, False
|
|
|
|
| 52 |
except Exception as primary_error:
|
| 53 |
-
print(f"Primary model failed: {primary_error}")
|
| 54 |
|
| 55 |
try:
|
| 56 |
output = self._call_model(prompt, self.fallback_model)
|
| 57 |
return output, self.fallback_model, True
|
|
|
|
| 58 |
except Exception as fallback_error:
|
| 59 |
-
print(f"Fallback model failed: {fallback_error}")
|
| 60 |
raise RuntimeError("Both primary and fallback models failed.")
|
| 61 |
|
| 62 |
|
|
|
|
| 12 |
self.timeout = settings.MODEL_TIMEOUT_SECONDS
|
| 13 |
self.temperature = settings.DEFAULT_TEMPERATURE
|
| 14 |
self.top_p = settings.DEFAULT_TOP_P
|
| 15 |
+
self.max_tokens = settings.MAX_OUTPUT_TOKENS
|
| 16 |
self.hf_token = getattr(settings, "HUGGINGFACE_API_TOKEN", "")
|
| 17 |
|
| 18 |
def _create_client(self) -> InferenceClient:
|
|
|
|
| 21 |
timeout=self.timeout,
|
| 22 |
)
|
| 23 |
|
| 24 |
+
def _extract_content(self, response) -> str:
|
| 25 |
+
if not response or not getattr(response, "choices", None):
|
| 26 |
+
raise RuntimeError("Empty response from model.")
|
| 27 |
+
|
| 28 |
+
first_choice = response.choices[0]
|
| 29 |
+
if not first_choice or not getattr(first_choice, "message", None):
|
| 30 |
+
raise RuntimeError("Model returned an invalid response structure.")
|
| 31 |
+
|
| 32 |
+
message = first_choice.message
|
| 33 |
+
content = getattr(message, "content", None)
|
| 34 |
+
|
| 35 |
+
if content is None:
|
| 36 |
+
raise RuntimeError("Model returned no content.")
|
| 37 |
+
|
| 38 |
+
cleaned = str(content).strip()
|
| 39 |
+
if not cleaned:
|
| 40 |
+
raise RuntimeError("Model returned empty content.")
|
| 41 |
+
|
| 42 |
+
return cleaned
|
| 43 |
+
|
| 44 |
def _call_model(self, prompt: str, model_name: str) -> str:
|
| 45 |
+
cleaned_prompt = str(prompt or "").strip()
|
| 46 |
+
if not cleaned_prompt:
|
| 47 |
+
raise RuntimeError("Prompt is empty.")
|
| 48 |
+
|
| 49 |
client = self._create_client()
|
| 50 |
|
| 51 |
response = client.chat.completions.create(
|
|
|
|
| 53 |
messages=[
|
| 54 |
{
|
| 55 |
"role": "user",
|
| 56 |
+
"content": cleaned_prompt,
|
| 57 |
}
|
| 58 |
],
|
| 59 |
temperature=self.temperature,
|
| 60 |
top_p=self.top_p,
|
| 61 |
+
max_tokens=self.max_tokens,
|
| 62 |
)
|
| 63 |
|
| 64 |
+
return self._extract_content(response)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
def generate(self, prompt: str) -> Tuple[str, str, bool]:
|
| 67 |
try:
|
| 68 |
output = self._call_model(prompt, self.primary_model)
|
| 69 |
return output, self.primary_model, False
|
| 70 |
+
|
| 71 |
except Exception as primary_error:
|
| 72 |
+
print(f"Primary model failed: {primary_error}", flush=True)
|
| 73 |
|
| 74 |
try:
|
| 75 |
output = self._call_model(prompt, self.fallback_model)
|
| 76 |
return output, self.fallback_model, True
|
| 77 |
+
|
| 78 |
except Exception as fallback_error:
|
| 79 |
+
print(f"Fallback model failed: {fallback_error}", flush=True)
|
| 80 |
raise RuntimeError("Both primary and fallback models failed.")
|
| 81 |
|
| 82 |
|