| import yaml | |
| import os | |
| import os | |
| import dotenv | |
| dotenv.load_dotenv() | |
| from openai import OpenAI | |
| from utils import count_tokens | |
# OpenAI API key taken from the process environment (which load_dotenv() above
# may have populated from a .env file). Falls back to an empty string when the
# variable is unset or empty, so this is always a str.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
class PromptCatalog:
    """Load prompt definitions from a catalog file and look them up by id/version.

    The catalog is parsed with ``yaml.safe_load``. JSON content, such as the
    default ``prompt_catalog.json``, is accepted because typical JSON
    documents are also valid YAML. Each entry is expected to be a mapping
    with ``"id"`` and ``"content"`` keys and, optionally, a ``"version"`` key.
    """

    def __init__(
        self,
        path=os.path.join(os.path.dirname(__file__), "..", "prompt_catalog.json"),
    ):
        """Read and parse the catalog at *path*.

        Args:
            path: Catalog file location. Defaults to ``prompt_catalog.json``
                in the directory above this module.
        """
        with open(path, "r", encoding="utf-8") as f:
            loaded = yaml.safe_load(f)
        # safe_load returns None for an empty document. Treat that as an empty
        # catalog so lookups raise ValueError rather than a TypeError.
        self.prompts = loaded if loaded is not None else []

    def get_prompt(self, prompt_id, version=None):
        """Return the ``"content"`` of the first catalog entry matching the query.

        Args:
            prompt_id: Value compared (``==``) against each entry's ``"id"``.
            version: If not ``None``, only entries whose ``"version"`` equals
                this value are considered. Falsy versions such as ``0`` are
                honoured. Equality is strict, so ``1`` and ``"1"`` do not
                match each other.

        Returns:
            The ``"content"`` of the first matching entry, in catalog order.

        Raises:
            ValueError: If no entry matches.
        """
        candidates = [p for p in self.prompts if p.get("id") == prompt_id]
        if version is not None:
            # Entries lacking a "version" key simply don't match a versioned request.
            candidates = [p for p in candidates if p.get("version") == version]
        if not candidates:
            label = prompt_id if version is None else f"{prompt_id} v{version}"
            raise ValueError(f"Prompt {label} not found")
        return candidates[0]["content"]
class LLM:
    """Chat-completions wrapper that always sends a fixed system prompt.

    The system prompt is resolved from ``PromptCatalog`` at construction time
    and sent as the first message of every request.
    """

    def __init__(self, model_name, prompt_id, version=None, client=None):
        """Resolve the system prompt and set up the API client.

        Args:
            model_name: Model identifier passed as ``model`` on each request.
            prompt_id: Catalog id of the system prompt.
            version: Optional catalog version, forwarded to
                ``PromptCatalog.get_prompt``.
            client: Optional pre-built client exposing
                ``chat.completions.create``, such as an ``OpenAI`` client with
                custom settings or a stub in tests. When omitted, an ``OpenAI``
                client is created with the module-level ``OPENAI_API_KEY``.

        Raises:
            ValueError: Propagated from ``PromptCatalog.get_prompt`` when the
                prompt is not found.
        """
        self.client = client if client is not None else OpenAI(api_key=OPENAI_API_KEY)
        self.model_name = model_name
        self.prompt_catalog = PromptCatalog()
        # Backward-compatible alias for the original, misspelled attribute name.
        self.prompt_cataglog = self.prompt_catalog
        self.system_prompt = self.prompt_catalog.get_prompt(prompt_id, version)
        self.system_prompt_tokens = count_tokens(self.system_prompt)
        print(f"System prompt tokens: {self.system_prompt_tokens}")

    def generate_response(self, input_text):
        """Send *input_text* as a user message and return the first choice's text.

        Args:
            input_text: Content of the user message.

        Returns:
            ``message.content`` of the first returned choice. The OpenAI SDK
            types this field as optional, so callers should be prepared for
            ``None``.
        """
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": input_text},
        ]
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
        )
        return response.choices[0].message.content