| import os |
| import requests |
| import time |
|
|
class LLMDriver:
    """Common interface implemented by every model backend in this module.

    Concrete drivers set ``name`` to a short backend identifier and override
    :meth:`generate_code`.
    """

    name: str = "base"  # backend identifier; concrete drivers set their own

    def generate_code(self, task: str) -> str:
        """Send *task* to the backend and return the model's text reply.

        Raises:
            NotImplementedError: always on this base class; concrete drivers
                must provide their own implementation.
        """
        raise NotImplementedError
|
|
|
|
class GroqDriver(LLMDriver):
    """Backend for Groq's OpenAI-compatible chat-completions endpoint.

    Environment (read once, at construction time):
        GROQ_API_KEY: bearer token sent with every request.
        GROQ_MODEL: model id; defaults to ``"llama-3.1-8b-instant"``.

    Args:
        max_tokens: completion-length cap sent as ``max_tokens``.
        temperature: sampling temperature sent with each request.
        timeout: seconds passed to ``requests.post``.
    """

    name = "groq"

    def __init__(self, max_tokens=300, temperature=0.2, timeout=30):
        self.api_key = os.getenv("GROQ_API_KEY")
        self.endpoint = "https://api.groq.com/openai/v1/chat/completions"
        self.model = os.getenv("GROQ_MODEL", "llama-3.1-8b-instant")
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.timeout = timeout

    def generate_code(self, task):
        """Send *task* as a single user message and return the reply text.

        Returns:
            The first choice's message content, or ``""`` when the API
            returned ``null`` content.

        Raises:
            requests.HTTPError: if ``GROQ_API_KEY`` is unset or empty (raised
                before any network call), or if the API answers 4xx/5xx.
            requests.RequestException: on connection errors or timeouts.
        """
        if not self.api_key:
            # Without a key the request would go out as "Bearer None" and be
            # rejected with a 401. Fail early with a clear message, but keep
            # the HTTPError type so existing handlers still catch it.
            raise requests.HTTPError(
                "GROQ_API_KEY is not set; cannot call the Groq API"
            )

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": task}],
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
        }

        r = requests.post(
            self.endpoint, headers=headers, json=payload, timeout=self.timeout
        )
        r.raise_for_status()
        # `content` may be JSON null; normalise so callers always get a str.
        content = r.json()["choices"][0]["message"]["content"]
        return content or ""
|
|
|
|
class OpenAIDriver(LLMDriver):
    """Backend for OpenAI's chat-completions endpoint.

    Environment (read once, at construction time):
        OPENAI_API_KEY: bearer token sent with every request.
        OPENAI_MODEL: model id; defaults to ``"gpt-4o-mini"``.

    Args:
        max_tokens: completion-length cap sent as ``max_tokens``.
        temperature: sampling temperature sent with each request.
        timeout: seconds passed to ``requests.post``.
    """

    name = "openai"

    def __init__(self, max_tokens=300, temperature=0.2, timeout=30):
        self.api_key = os.getenv("OPENAI_API_KEY")
        self.endpoint = "https://api.openai.com/v1/chat/completions"
        self.model = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.timeout = timeout

    def generate_code(self, task):
        """Send *task* as a single user message and return the reply text.

        Returns:
            The first choice's message content, or ``""`` when the API
            returned ``null`` content.

        Raises:
            requests.HTTPError: if ``OPENAI_API_KEY`` is unset or empty
                (raised before any network call), or if the API answers
                4xx/5xx.
            requests.RequestException: on connection errors or timeouts.
        """
        if not self.api_key:
            # Without a key the request would go out as "Bearer None" and be
            # rejected with a 401. Fail early with a clear message, but keep
            # the HTTPError type so existing handlers still catch it.
            raise requests.HTTPError(
                "OPENAI_API_KEY is not set; cannot call the OpenAI API"
            )

        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload = {
            "model": self.model,
            "messages": [{"role": "user", "content": task}],
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
        }

        r = requests.post(
            self.endpoint, headers=headers, json=payload, timeout=self.timeout
        )
        r.raise_for_status()
        # `content` may be JSON null (e.g. a refusal); normalise to a str.
        content = r.json()["choices"][0]["message"]["content"]
        return content or ""
|
|
|
|
class HuggingFaceDriver(LLMDriver):
    """Backend for the Hugging Face Inference API (text-generation task).

    Environment (read once, at construction time):
        HF_API_TOKEN: optional bearer token; when unset or empty the request
            is sent without an Authorization header.
        HF_MODEL: model repo id; defaults to
            ``"meta-llama/Meta-Llama-3-8B-Instruct"``. The endpoint URL is
            derived from it here and not re-read later.

    Args:
        max_tokens: generation-length cap, sent as
            ``parameters.max_new_tokens``.
        temperature: sampling temperature, sent as ``parameters.temperature``.
        timeout: seconds passed to ``requests.post``.
    """

    name = "huggingface"

    def __init__(self, max_tokens=300, temperature=0.2, timeout=30):
        self.api_key = os.getenv("HF_API_TOKEN")
        self.model = os.getenv("HF_MODEL", "meta-llama/Meta-Llama-3-8B-Instruct")
        self.endpoint = f"https://api-inference.huggingface.co/models/{self.model}"
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.timeout = timeout

    def generate_code(self, task):
        """Send *task* as a raw prompt and return only the generated text.

        Returns:
            The ``generated_text`` of the first result, or ``""`` when the
            response contains no result or no text.

        Raises:
            requests.HTTPError: if the API answers 4xx/5xx.
            requests.RequestException: on connection errors or timeouts.
        """
        headers = {}
        if self.api_key:
            # Only authenticate when a token exists; otherwise the header
            # would read "Bearer None".
            headers["Authorization"] = f"Bearer {self.api_key}"

        payload = {
            "inputs": task,
            "parameters": {
                "max_new_tokens": self.max_tokens,
                "temperature": self.temperature,
                # Text generation echoes the prompt by default; return only
                # the continuation, matching what the chat drivers return.
                "return_full_text": False,
            },
        }

        r = requests.post(
            self.endpoint, headers=headers, json=payload, timeout=self.timeout
        )
        r.raise_for_status()
        data = r.json()

        # Results usually arrive as a list; a bare dict is accepted too, and
        # an empty list yields "" instead of an IndexError.
        if isinstance(data, list):
            data = data[0] if data else {}
        return data.get("generated_text") or ""