hashan-7 commited on
Commit
f8eaac1
·
verified ·
1 Parent(s): fc95190

fix the errors

Browse files
Files changed (1) hide show
  1. model_client.py +37 -53
model_client.py CHANGED
@@ -1,7 +1,10 @@
1
- import requests
2
- from typing import Optional, Tuple
 
 
3
  from config import settings
4
 
 
5
  class ModelClient:
6
  def __init__(self):
7
  self.primary_model = settings.PRIMARY_CODE_MODEL
@@ -9,71 +12,52 @@ class ModelClient:
9
  self.timeout = settings.MODEL_TIMEOUT_SECONDS
10
  self.temperature = settings.DEFAULT_TEMPERATURE
11
  self.top_p = settings.DEFAULT_TOP_P
 
12
 
13
- def _build_payload(self, prompt: str, model_name: str) -> dict:
14
- return {
15
- "inputs": prompt,
16
- "parameters": {
17
- "temperature": self.temperature,
18
- "top_p": self.top_p,
19
- "return_full_text": False,
20
- },
21
- "options": {
22
- "wait_for_model": True,
23
- "use_cache": False,
24
- }
25
- }
26
-
27
- def _extract_text(self, response_json) -> str:
28
- if isinstance(response_json, list) and len(response_json) > 0:
29
- first_item = response_json
30
- if isinstance(first_item, dict) and "generated_text" in first_item:
31
- return str(first_item["generated_text"]).strip()
32
-
33
- if isinstance(response_json, dict):
34
- if "generated_text" in response_json:
35
- return str(response_json["generated_text"]).strip()
36
- if "error" in response_json:
37
- raise RuntimeError(str(response_json["error"]).strip())
38
-
39
- raise RuntimeError("Invalid model response format.")
40
 
41
- def _call_huggingface_model(self, prompt: str, model_name: str) -> str:
42
-
43
- api_url = f"https://api-inference.huggingface.co/models/{model_name}"
44
-
45
- headers = {}
46
- hf_token = getattr(settings, "HUGGINGFACE_API_TOKEN", "")
47
- if hf_token:
48
- headers["Authorization"] = f"Bearer {hf_token}"
 
 
 
 
 
 
 
49
 
50
- payload = self._build_payload(prompt, model_name)
 
51
 
52
- response = requests.post(
53
- api_url,
54
- headers=headers,
55
- json=payload,
56
- timeout=self.timeout,
57
- )
58
-
59
-
60
- if response.status_code == 404 or "no longer supported" in response.text:
61
- api_url = f"https://api-inference.huggingface.co/models/{model_name}"
62
 
63
- response.raise_for_status()
64
- return self._extract_text(response.json())
65
 
66
  def generate(self, prompt: str) -> Tuple[str, str, bool]:
67
  try:
68
- output = self._call_huggingface_model(prompt, self.primary_model)
69
  return output, self.primary_model, False
70
  except Exception as primary_error:
71
  print(f"Primary model failed: {primary_error}")
 
72
  try:
73
- output = self._call_huggingface_model(prompt, self.fallback_model)
74
  return output, self.fallback_model, True
75
  except Exception as fallback_error:
76
  print(f"Fallback model failed: {fallback_error}")
77
- raise RuntimeError(f"Both primary and fallback models failed.")
 
78
 
79
  model_client = ModelClient()
 
1
+ from typing import Tuple
2
+
3
+ from huggingface_hub import InferenceClient
4
+
5
  from config import settings
6
 
7
+
8
  class ModelClient:
9
  def __init__(self):
10
  self.primary_model = settings.PRIMARY_CODE_MODEL
 
12
  self.timeout = settings.MODEL_TIMEOUT_SECONDS
13
  self.temperature = settings.DEFAULT_TEMPERATURE
14
  self.top_p = settings.DEFAULT_TOP_P
15
+ self.hf_token = getattr(settings, "HUGGINGFACE_API_TOKEN", "")
16
 
17
+ def _create_client(self) -> InferenceClient:
18
+ return InferenceClient(
19
+ api_key=self.hf_token if self.hf_token else None,
20
+ timeout=self.timeout,
21
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ def _call_model(self, prompt: str, model_name: str) -> str:
24
+ client = self._create_client()
25
+
26
+ response = client.chat.completions.create(
27
+ model=model_name,
28
+ messages=[
29
+ {
30
+ "role": "user",
31
+ "content": prompt,
32
+ }
33
+ ],
34
+ temperature=self.temperature,
35
+ top_p=self.top_p,
36
+ max_tokens=1200,
37
+ )
38
 
39
+ if not response or not response.choices:
40
+ raise RuntimeError("Empty response from model.")
41
 
42
+ message = response.choices[0].message
43
+ if not message or not message.content:
44
+ raise RuntimeError("Model returned no content.")
 
 
 
 
 
 
 
45
 
46
+ return str(message.content).strip()
 
47
 
48
  def generate(self, prompt: str) -> Tuple[str, str, bool]:
49
  try:
50
+ output = self._call_model(prompt, self.primary_model)
51
  return output, self.primary_model, False
52
  except Exception as primary_error:
53
  print(f"Primary model failed: {primary_error}")
54
+
55
  try:
56
+ output = self._call_model(prompt, self.fallback_model)
57
  return output, self.fallback_model, True
58
  except Exception as fallback_error:
59
  print(f"Fallback model failed: {fallback_error}")
60
+ raise RuntimeError("Both primary and fallback models failed.")
61
+
62
 
63
  model_client = ModelClient()