hashan-7 commited on
Commit
44b62ba
·
verified ·
1 Parent(s): 7f9a28e

Update code

Browse files
Files changed (1) hide show
  1. model_client.py +16 -8
model_client.py CHANGED
@@ -1,5 +1,5 @@
1
  import requests
2
- from typing import Optional, Tuple
3
 
4
  from config import settings
5
 
@@ -12,7 +12,7 @@ class ModelClient:
12
  self.temperature = settings.DEFAULT_TEMPERATURE
13
  self.top_p = settings.DEFAULT_TOP_P
14
 
15
- def _build_payload(self, prompt: str, model_name: str) -> dict:
16
  return {
17
  "inputs": prompt,
18
  "parameters": {
@@ -24,15 +24,13 @@ class ModelClient:
24
  "wait_for_model": True,
25
  "use_cache": False,
26
  },
27
- "model": model_name,
28
  }
29
 
30
  def _extract_text(self, response_json) -> str:
31
  if isinstance(response_json, list) and len(response_json) > 0:
32
  first_item = response_json[0]
33
- if isinstance(first_item, dict):
34
- if "generated_text" in first_item:
35
- return str(first_item["generated_text"]).strip()
36
 
37
  if isinstance(response_json, dict):
38
  if "generated_text" in response_json:
@@ -50,7 +48,7 @@ class ModelClient:
50
  if hf_token:
51
  headers["Authorization"] = f"Bearer {hf_token}"
52
 
53
- payload = self._build_payload(prompt, model_name)
54
 
55
  response = requests.post(
56
  api_url,
@@ -58,7 +56,17 @@ class ModelClient:
58
  json=payload,
59
  timeout=self.timeout,
60
  )
61
- response.raise_for_status()
 
 
 
 
 
 
 
 
 
 
62
 
63
  return self._extract_text(response.json())
64
 
 
1
  import requests
2
+ from typing import Tuple
3
 
4
  from config import settings
5
 
 
12
  self.temperature = settings.DEFAULT_TEMPERATURE
13
  self.top_p = settings.DEFAULT_TOP_P
14
 
15
+ def _build_payload(self, prompt: str) -> dict:
16
  return {
17
  "inputs": prompt,
18
  "parameters": {
 
24
  "wait_for_model": True,
25
  "use_cache": False,
26
  },
 
27
  }
28
 
29
  def _extract_text(self, response_json) -> str:
30
  if isinstance(response_json, list) and len(response_json) > 0:
31
  first_item = response_json[0]
32
+ if isinstance(first_item, dict) and "generated_text" in first_item:
33
+ return str(first_item["generated_text"]).strip()
 
34
 
35
  if isinstance(response_json, dict):
36
  if "generated_text" in response_json:
 
48
  if hf_token:
49
  headers["Authorization"] = f"Bearer {hf_token}"
50
 
51
+ payload = self._build_payload(prompt)
52
 
53
  response = requests.post(
54
  api_url,
 
56
  json=payload,
57
  timeout=self.timeout,
58
  )
59
+
60
+ try:
61
+ response.raise_for_status()
62
+ except requests.HTTPError:
63
+ try:
64
+ error_json = response.json()
65
+ if isinstance(error_json, dict) and "error" in error_json:
66
+ raise RuntimeError(str(error_json["error"]).strip())
67
+ except ValueError:
68
+ pass
69
+ raise
70
 
71
  return self._extract_text(response.json())
72