hashan-7 committed on
Commit
895ed04
·
verified ·
1 Parent(s): 3f6ba20

add the code

Browse files
Files changed (1) hide show
  1. model_client.py +80 -0
model_client.py CHANGED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from typing import Optional, Tuple
3
+
4
+ from config import settings
5
+
6
+
7
class ModelClient:
    """Text-generation client with a primary model and a fallback model.

    Requests are POSTed to
    ``https://api-inference.huggingface.co/models/<model_name>``. The model
    names, request timeout and sampling parameters are read from
    ``settings`` once, when the client is constructed.
    """

    def __init__(self):
        self.primary_model = settings.PRIMARY_CODE_MODEL
        self.fallback_model = settings.FALLBACK_CODE_MODEL
        self.timeout = settings.MODEL_TIMEOUT_SECONDS
        self.temperature = settings.DEFAULT_TEMPERATURE
        self.top_p = settings.DEFAULT_TOP_P

    def _build_payload(self, prompt: str, model_name: str) -> dict:
        """Return the JSON request body for ``prompt`` aimed at ``model_name``."""
        return {
            "inputs": prompt,
            "parameters": {
                "temperature": self.temperature,
                "top_p": self.top_p,
                "return_full_text": False,
            },
            "options": {
                "wait_for_model": True,
                "use_cache": False,
            },
            "model": model_name,
        }

    def _extract_text(self, response_json) -> str:
        """Pull the generated text out of a decoded JSON response.

        Accepted shapes are a non-empty list whose first element is a dict
        holding ``"generated_text"``, or a dict holding ``"generated_text"``.

        Raises:
            RuntimeError: with the API's message when the response is a dict
                carrying ``"error"``, or with a generic message for any other
                unrecognised shape.
        """
        if isinstance(response_json, list) and response_json:
            first_item = response_json[0]
            if isinstance(first_item, dict) and "generated_text" in first_item:
                return str(first_item["generated_text"]).strip()

        if isinstance(response_json, dict):
            if "generated_text" in response_json:
                return str(response_json["generated_text"]).strip()
            if "error" in response_json:
                raise RuntimeError(str(response_json["error"]).strip())

        raise RuntimeError("Invalid model response format.")

    @staticmethod
    def _error_detail(response) -> str:
        """Return the ``"error"`` text from a response's JSON body, or ``""``.

        Used to enrich HTTP failures; a body that is not JSON, or not a dict
        with an ``"error"`` key, yields an empty string.
        """
        try:
            body = response.json()
        except ValueError:
            # Body is not valid JSON, so there is nothing useful to report.
            return ""
        if isinstance(body, dict) and "error" in body:
            return str(body["error"]).strip()
        return ""

    def _call_huggingface_model(self, prompt: str, model_name: str) -> str:
        """POST ``prompt`` to ``model_name`` and return the generated text.

        An ``Authorization: Bearer`` header is sent only when
        ``settings.HUGGINGFACE_API_TOKEN`` is set to a non-empty value.

        Raises:
            requests.HTTPError: on a 4xx/5xx status. When the response body
                carries an ``"error"`` message it is appended to the
                exception text so the reason is not lost.
            RuntimeError: when the response body has an unexpected shape
                (see ``_extract_text``).
        """
        api_url = f"https://api-inference.huggingface.co/models/{model_name}"
        headers = {}

        hf_token = getattr(settings, "HUGGINGFACE_API_TOKEN", "")
        if hf_token:
            headers["Authorization"] = f"Bearer {hf_token}"

        payload = self._build_payload(prompt, model_name)

        response = requests.post(
            api_url,
            headers=headers,
            json=payload,
            timeout=self.timeout,
        )

        try:
            response.raise_for_status()
        except requests.HTTPError as http_error:
            # The status line alone rarely explains the failure; the JSON
            # body's "error" field usually does, so surface it when present.
            detail = self._error_detail(response)
            if not detail:
                raise
            raise requests.HTTPError(
                f"{http_error} - {detail}", response=response
            ) from http_error

        return self._extract_text(response.json())

    def generate(self, prompt: str) -> Tuple[str, str, bool]:
        """Generate text for ``prompt``, falling back to the second model.

        Returns:
            A ``(text, model_name, used_fallback)`` tuple, where
            ``used_fallback`` is ``True`` only when the primary model failed
            and the fallback model produced the text.

        Raises:
            RuntimeError: when both models fail. The fallback model's
                exception is chained as ``__cause__``.
        """
        # Any failure of the primary model (network, HTTP status, malformed
        # response) should trigger the fallback, hence the broad catch.
        try:
            output = self._call_huggingface_model(prompt, self.primary_model)
        except Exception as primary_error:
            print(f"Primary model failed: {primary_error}")
        else:
            return output, self.primary_model, False

        try:
            output = self._call_huggingface_model(prompt, self.fallback_model)
        except Exception as fallback_error:
            print(f"Fallback model failed: {fallback_error}")
            raise RuntimeError(
                "Both primary and fallback models failed."
            ) from fallback_error
        return output, self.fallback_model, True
78
+
79
+
80
# Module-level ModelClient instance, constructed when this module is imported
# (presumably shared by importers of this module - verify against callers).
+ model_client = ModelClient()