hashan-7 commited on
Commit
fc95190
·
verified ·
1 Parent(s): a9ad625

Update model_client.py

Browse files
Files changed (1) hide show
  1. model_client.py +15 -24
model_client.py CHANGED
@@ -1,9 +1,7 @@
1
  import requests
2
- from typing import Tuple
3
-
4
  from config import settings
5
 
6
-
7
  class ModelClient:
8
  def __init__(self):
9
  self.primary_model = settings.PRIMARY_CODE_MODEL
@@ -12,7 +10,7 @@ class ModelClient:
12
  self.temperature = settings.DEFAULT_TEMPERATURE
13
  self.top_p = settings.DEFAULT_TOP_P
14
 
15
- def _build_payload(self, prompt: str) -> dict:
16
  return {
17
  "inputs": prompt,
18
  "parameters": {
@@ -23,32 +21,33 @@ class ModelClient:
23
  "options": {
24
  "wait_for_model": True,
25
  "use_cache": False,
26
- },
27
  }
28
 
29
  def _extract_text(self, response_json) -> str:
30
  if isinstance(response_json, list) and len(response_json) > 0:
31
- first_item = response_json[0]
32
  if isinstance(first_item, dict) and "generated_text" in first_item:
33
  return str(first_item["generated_text"]).strip()
34
-
35
  if isinstance(response_json, dict):
36
  if "generated_text" in response_json:
37
  return str(response_json["generated_text"]).strip()
38
  if "error" in response_json:
39
  raise RuntimeError(str(response_json["error"]).strip())
40
-
41
  raise RuntimeError("Invalid model response format.")
42
 
43
  def _call_huggingface_model(self, prompt: str, model_name: str) -> str:
 
44
  api_url = f"https://api-inference.huggingface.co/models/{model_name}"
 
45
  headers = {}
46
-
47
  hf_token = getattr(settings, "HUGGINGFACE_API_TOKEN", "")
48
  if hf_token:
49
  headers["Authorization"] = f"Bearer {hf_token}"
50
 
51
- payload = self._build_payload(prompt)
52
 
53
  response = requests.post(
54
  api_url,
@@ -56,18 +55,12 @@ class ModelClient:
56
  json=payload,
57
  timeout=self.timeout,
58
  )
 
 
 
 
59
 
60
- try:
61
- response.raise_for_status()
62
- except requests.HTTPError:
63
- try:
64
- error_json = response.json()
65
- if isinstance(error_json, dict) and "error" in error_json:
66
- raise RuntimeError(str(error_json["error"]).strip())
67
- except ValueError:
68
- pass
69
- raise
70
-
71
  return self._extract_text(response.json())
72
 
73
  def generate(self, prompt: str) -> Tuple[str, str, bool]:
@@ -76,13 +69,11 @@ class ModelClient:
76
  return output, self.primary_model, False
77
  except Exception as primary_error:
78
  print(f"Primary model failed: {primary_error}")
79
-
80
  try:
81
  output = self._call_huggingface_model(prompt, self.fallback_model)
82
  return output, self.fallback_model, True
83
  except Exception as fallback_error:
84
  print(f"Fallback model failed: {fallback_error}")
85
- raise RuntimeError("Both primary and fallback models failed.")
86
-
87
 
88
  model_client = ModelClient()
 
1
  import requests
2
+ from typing import Optional, Tuple
 
3
  from config import settings
4
 
 
5
  class ModelClient:
6
  def __init__(self):
7
  self.primary_model = settings.PRIMARY_CODE_MODEL
 
10
  self.temperature = settings.DEFAULT_TEMPERATURE
11
  self.top_p = settings.DEFAULT_TOP_P
12
 
13
+ def _build_payload(self, prompt: str, model_name: str) -> dict:
14
  return {
15
  "inputs": prompt,
16
  "parameters": {
 
21
  "options": {
22
  "wait_for_model": True,
23
  "use_cache": False,
24
+ }
25
  }
26
 
27
  def _extract_text(self, response_json) -> str:
28
  if isinstance(response_json, list) and len(response_json) > 0:
29
+ first_item = response_json
30
  if isinstance(first_item, dict) and "generated_text" in first_item:
31
  return str(first_item["generated_text"]).strip()
32
+
33
  if isinstance(response_json, dict):
34
  if "generated_text" in response_json:
35
  return str(response_json["generated_text"]).strip()
36
  if "error" in response_json:
37
  raise RuntimeError(str(response_json["error"]).strip())
38
+
39
  raise RuntimeError("Invalid model response format.")
40
 
41
  def _call_huggingface_model(self, prompt: str, model_name: str) -> str:
42
+
43
  api_url = f"https://api-inference.huggingface.co/models/{model_name}"
44
+
45
  headers = {}
 
46
  hf_token = getattr(settings, "HUGGINGFACE_API_TOKEN", "")
47
  if hf_token:
48
  headers["Authorization"] = f"Bearer {hf_token}"
49
 
50
+ payload = self._build_payload(prompt, model_name)
51
 
52
  response = requests.post(
53
  api_url,
 
55
  json=payload,
56
  timeout=self.timeout,
57
  )
58
+
59
+
60
+ if response.status_code == 404 or "no longer supported" in response.text:
61
+ api_url = f"https://api-inference.huggingface.co/models/{model_name}"
62
 
63
+ response.raise_for_status()
 
 
 
 
 
 
 
 
 
 
64
  return self._extract_text(response.json())
65
 
66
  def generate(self, prompt: str) -> Tuple[str, str, bool]:
 
69
  return output, self.primary_model, False
70
  except Exception as primary_error:
71
  print(f"Primary model failed: {primary_error}")
 
72
  try:
73
  output = self._call_huggingface_model(prompt, self.fallback_model)
74
  return output, self.fallback_model, True
75
  except Exception as fallback_error:
76
  print(f"Fallback model failed: {fallback_error}")
77
+ raise RuntimeError(f"Both primary and fallback models failed.")
 
78
 
79
  model_client = ModelClient()