Rahul-8799 committed on
Commit
d655b62
·
verified ·
1 Parent(s): d00cb32

Update utils/inference.py

Browse files
Files changed (1) hide show
  1. utils/inference.py +33 -7
utils/inference.py CHANGED
@@ -1,11 +1,37 @@
1
- from huggingface_hub import InferenceClient
2
  import os
 
3
 
4
- client = InferenceClient(
5
- model="bigcode/starcoder2-3b",
6
- token=os.environ.get("HF_TOKEN"),
7
- provider="together"
8
- )
 
 
 
9
 
10
  def call_model(prompt: str) -> str:
11
- return client.text_generation(prompt, max_new_tokens=2048, temperature=0.3, return_full_text=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import requests
3
 
4
+ # Replace this with your actual endpoint URL
5
+ API_URL = "https://lgj704z9p0j2vf79.us-east4.gcp.endpoints.huggingface.cloud" # e.g., https://mistral-rw123.hf.space
6
+ HF_ENDPOINT_TOKEN = os.environ.get("HF_ENDPOINT_TOKEN")
7
+
8
+ headers = {
9
+ "Authorization": f"Bearer {HF_ENDPOINT_TOKEN}",
10
+ "Content-Type": "application/json"
11
+ }
12
 
13
  def call_model(prompt: str) -> str:
14
+ response = requests.post(
15
+ f"{API_URL}/v1/completions", # Use `/v1/completions` or `/generate` depending on config
16
+ headers=headers,
17
+ json={
18
+ "inputs": prompt,
19
+ "parameters": {
20
+ "max_new_tokens": 2048,
21
+ "temperature": 0.3,
22
+ "do_sample": False
23
+ }
24
+ }
25
+ )
26
+
27
+ if response.status_code != 200:
28
+ raise RuntimeError(f"Inference error: {response.status_code} - {response.text}")
29
+
30
+ # The response schema may vary slightly; adjust if needed:
31
+ result = response.json()
32
+ if isinstance(result, dict) and "generated_text" in result:
33
+ return result["generated_text"]
34
+ elif isinstance(result, list) and "generated_text" in result[0]:
35
+ return result[0]["generated_text"]
36
+ else:
37
+ return result.get("data", "⚠️ No output generated.")