LordXido commited on
Commit
84b0021
·
verified ·
1 Parent(s): 647b435

Create llm_drivers.py

Browse files
Files changed (1) hide show
  1. llm_drivers.py +82 -0
llm_drivers.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import time
4
+
5
+ class LLMDriver:
6
+ name = "base"
7
+
8
+ def generate_code(self, task: str) -> str:
9
+ raise NotImplementedError
10
+
11
+
12
+ class GroqDriver(LLMDriver):
13
+ name = "groq"
14
+
15
+ def __init__(self):
16
+ self.api_key = os.getenv("GROQ_API_KEY")
17
+ self.endpoint = "https://api.groq.com/openai/v1/chat/completions"
18
+ self.model = os.getenv("GROQ_MODEL", "llama-3.1-8b-instant")
19
+
20
+ def generate_code(self, task):
21
+ headers = {
22
+ "Authorization": f"Bearer {self.api_key}",
23
+ "Content-Type": "application/json"
24
+ }
25
+
26
+ payload = {
27
+ "model": self.model,
28
+ "messages": [{"role": "user", "content": task}],
29
+ "max_tokens": 300,
30
+ "temperature": 0.2
31
+ }
32
+
33
+ r = requests.post(self.endpoint, headers=headers, json=payload, timeout=30)
34
+ r.raise_for_status()
35
+ return r.json()["choices"][0]["message"]["content"]
36
+
37
+
38
+ class OpenAIDriver(LLMDriver):
39
+ name = "openai"
40
+
41
+ def __init__(self):
42
+ self.api_key = os.getenv("OPENAI_API_KEY")
43
+ self.endpoint = "https://api.openai.com/v1/chat/completions"
44
+ self.model = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
45
+
46
+ def generate_code(self, task):
47
+ headers = {
48
+ "Authorization": f"Bearer {self.api_key}",
49
+ "Content-Type": "application/json"
50
+ }
51
+
52
+ payload = {
53
+ "model": self.model,
54
+ "messages": [{"role": "user", "content": task}],
55
+ "max_tokens": 300,
56
+ "temperature": 0.2
57
+ }
58
+
59
+ r = requests.post(self.endpoint, headers=headers, json=payload, timeout=30)
60
+ r.raise_for_status()
61
+ return r.json()["choices"][0]["message"]["content"]
62
+
63
+
64
+ class HuggingFaceDriver(LLMDriver):
65
+ name = "huggingface"
66
+
67
+ def __init__(self):
68
+ self.api_key = os.getenv("HF_API_TOKEN")
69
+ self.model = os.getenv("HF_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
70
+ self.endpoint = f"https://api-inference.huggingface.co/models/{self.model}"
71
+
72
+ def generate_code(self, task):
73
+ headers = {"Authorization": f"Bearer {self.api_key}"}
74
+ payload = {"inputs": task}
75
+
76
+ r = requests.post(self.endpoint, headers=headers, json=payload, timeout=30)
77
+ r.raise_for_status()
78
+ data = r.json()
79
+
80
+ if isinstance(data, list):
81
+ return data[0].get("generated_text", "")
82
+ return data.get("generated_text", "")