# Source: Hugging Face upload by LordXido — "Create llm_drivers.py", commit 84b0021 (verified)
import os
import requests
import time
class LLMDriver:
    """Common interface for pluggable text-generation backends.

    Concrete drivers identify themselves via the ``name`` class attribute
    and implement :meth:`generate_code`.
    """

    # Identifier used to select a backend; overridden by each subclass.
    name = "base"

    def generate_code(self, task: str) -> str:
        """Return generated text for *task*; concrete drivers must override."""
        raise NotImplementedError
class GroqDriver(LLMDriver):
    """Chat-completion driver for Groq's OpenAI-compatible endpoint."""

    name = "groq"

    def __init__(self):
        # The key may legitimately be absent at construction time (e.g. the
        # driver is registered but never selected); we only fail with a clear
        # error when a request is actually attempted.
        self.api_key = os.getenv("GROQ_API_KEY")
        self.endpoint = "https://api.groq.com/openai/v1/chat/completions"
        self.model = os.getenv("GROQ_MODEL", "llama-3.1-8b-instant")

    def generate_code(self, task: str) -> str:
        """Send *task* as a single user message and return the reply text.

        Raises:
            RuntimeError: if GROQ_API_KEY is not set.
            requests.HTTPError: on a non-2xx API response.
        """
        if not self.api_key:
            # Without this guard the request is sent as "Bearer None" and
            # fails with a confusing 401 from the server.
            raise RuntimeError("GROQ_API_KEY is not set")
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": task}],
            "max_tokens": 300,
            "temperature": 0.2,
        }
        r = requests.post(self.endpoint, headers=headers, json=payload, timeout=30)
        r.raise_for_status()
        return r.json()["choices"][0]["message"]["content"]
class OpenAIDriver(LLMDriver):
    """Chat-completion driver for the OpenAI API."""

    name = "openai"

    def __init__(self):
        # The key may legitimately be absent at construction time; we only
        # fail with a clear error when a request is actually attempted.
        self.api_key = os.getenv("OPENAI_API_KEY")
        self.endpoint = "https://api.openai.com/v1/chat/completions"
        self.model = os.getenv("OPENAI_MODEL", "gpt-4o-mini")

    def generate_code(self, task: str) -> str:
        """Send *task* as a single user message and return the reply text.

        Raises:
            RuntimeError: if OPENAI_API_KEY is not set.
            requests.HTTPError: on a non-2xx API response.
        """
        if not self.api_key:
            # Without this guard the request is sent as "Bearer None" and
            # fails with a confusing 401 from the server.
            raise RuntimeError("OPENAI_API_KEY is not set")
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": task}],
            "max_tokens": 300,
            "temperature": 0.2,
        }
        r = requests.post(self.endpoint, headers=headers, json=payload, timeout=30)
        r.raise_for_status()
        return r.json()["choices"][0]["message"]["content"]
class HuggingFaceDriver(LLMDriver):
    """Driver for the Hugging Face serverless Inference API."""

    name = "huggingface"

    def __init__(self):
        # The token may legitimately be absent at construction time; we only
        # fail with a clear error when a request is actually attempted.
        self.api_key = os.getenv("HF_API_TOKEN")
        self.model = os.getenv("HF_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
        # HF routes by model id in the URL rather than in the payload.
        self.endpoint = f"https://api-inference.huggingface.co/models/{self.model}"

    def generate_code(self, task: str) -> str:
        """POST *task* as raw inference input and return the generated text.

        Returns "" when the response carries no ``generated_text``.

        Raises:
            RuntimeError: if HF_API_TOKEN is not set.
            requests.HTTPError: on a non-2xx API response.
        """
        if not self.api_key:
            # Without this guard the request is sent as "Bearer None" and
            # fails with a confusing 401 from the server.
            raise RuntimeError("HF_API_TOKEN is not set")
        headers = {"Authorization": f"Bearer {self.api_key}"}
        payload = {"inputs": task}
        r = requests.post(self.endpoint, headers=headers, json=payload, timeout=30)
        r.raise_for_status()
        data = r.json()
        # The Inference API returns either a list of generation dicts or a
        # single dict depending on the model/pipeline; the original code
        # crashed with IndexError on an empty list, so guard for that too.
        if isinstance(data, list):
            return data[0].get("generated_text", "") if data else ""
        return data.get("generated_text", "")