| | """ |
| | This class contains multiple LLMs and handles LLMs response |
| | """ |
| |
|
| | import json |
| | import time |
| | from openai import OpenAI |
| | import openai |
| | import torch |
| | import re |
| | import anthropic |
| | import os |
| | import streamlit as st |
| | from google.genai import types |
| | from google import genai |
| |
|
| |
|
| |
|
class LLM:
    """Dispatches prompts to one of several LLM providers and normalizes responses."""

    def __init__(self, Core):
        # Core supplies model selection, config_file, temperature and token counters.
        self.Core = Core
        self.model = None
        # Provider family ("openai" | "inference" | "google" | "claude");
        # default is refined by connect_to_llm() from Core.model.
        self.model_type = "openai"
        # Provider SDK client; only the Claude branch of connect_to_llm() sets this.
        self.client = None
        self.connect_to_llm()
|
| | def get_credential(self, key): |
| | return os.getenv(key) or st.secrets.get(key) |
| |
|
| | def get_response(self, prompt, instructions): |
| | if self.model_type == "openai": |
| | response = self.get_message_openai(prompt, instructions) |
| | |
| | |
| | elif self.model_type == "inference": |
| | response = self.get_message_inference(prompt, instructions) |
| | elif self.model_type == "claude": |
| | response = self.get_message_claude(prompt, instructions) |
| | elif self.model_type == "google": |
| | response = self.get_message_google(prompt, instructions) |
| | else: |
| | raise f"Invalid model type : {self.model_type}" |
| |
|
| | return response |
| |
|
| | def connect_to_llm(self): |
| | """ |
| | connect to selected llm -> ollama or openai connection |
| | :return: |
| | """ |
| |
|
| | if self.Core.model in self.Core.config_file["openai_models"]: |
| | self.model_type = "openai" |
| |
|
| | elif self.Core.model in self.Core.config_file["inference_models"]: |
| | self.model_type = "inference" |
| |
|
| | elif self.Core.model in self.Core.config_file["google_models"]: |
| | self.model_type = "google" |
| |
|
| | |
| | |
| | |
| |
|
| | elif self.Core.model in self.Core.config_file["claude_models"]: |
| | self.model_type = "claude" |
| | self.client = anthropic.Anthropic( |
| | api_key=self.get_credential('claude_api_key'), |
| | ) |
| |
|
| | self.model = self.Core.model |
| |
|
| | |
| | def get_message_inference(self, prompt, instructions, max_retries=6): |
| | """ |
| | sending the prompt to openai LLM and get back the response |
| | """ |
| |
|
| | api_key = self.get_credential('inference_api_key') |
| | client = OpenAI( |
| | base_url="https://api.inference.net/v1", |
| | api_key=api_key, |
| | ) |
| |
|
| | for attempt in range(max_retries): |
| | try: |
| | if self.Core.reasoning_model: |
| | response = client.chat.completions.create( |
| | model=self.Core.model, |
| | response_format={"type": "json_object"}, |
| | messages=[ |
| | {"role": "system", "content": instructions}, |
| | {"role": "user", "content": prompt} |
| | ], |
| | reasoning_effort="medium", |
| | n=1, |
| | ) |
| |
|
| | else: |
| | response = client.chat.completions.create( |
| | model=self.Core.model, |
| | response_format={"type": "json_object"}, |
| | messages=[ |
| | {"role": "system", "content": instructions}, |
| | {"role": "user", "content": prompt} |
| | ], |
| | n=1, |
| | temperature=self.Core.temperature |
| | ) |
| |
|
| | tokens = { |
| | 'prompt_tokens': response.usage.prompt_tokens, |
| | 'completion_tokens': response.usage.completion_tokens, |
| | 'total_tokens': response.usage.total_tokens |
| | } |
| |
|
| | |
| | self.Core.total_tokens['prompt_tokens'] += tokens['prompt_tokens'] |
| | self.Core.total_tokens['completion_tokens'] += tokens['completion_tokens'] |
| | self.Core.temp_token_counter += tokens['total_tokens'] |
| |
|
| | try: |
| | content = response.choices[0].message.content |
| |
|
| | |
| |
|
| | output = json.loads(content) |
| |
|
| | if 'message' not in output or 'header' not in output: |
| | print(f"'message' or 'header' is missing in response on attempt {attempt + 1}. Retrying...") |
| | continue |
| |
|
| | else: |
| | if len(output["header"].strip()) > self.Core.config_file["header_limit"] or len( |
| | output["message"].strip()) > self.Core.config_file["message_limit"]: |
| | print( |
| | f"'header' or 'message' is more than specified characters in response on attempt {attempt + 1}. Retrying...") |
| | continue |
| |
|
| | return output |
| |
|
| | except json.JSONDecodeError: |
| | print(f"Invalid JSON from LLM on attempt {attempt + 1}. Retrying...") |
| |
|
| | except openai.APIConnectionError as e: |
| | print("The server could not be reached") |
| | print(e.__cause__) |
| | except openai.RateLimitError as e: |
| | print("A 429 status code was received; we should back off a bit.") |
| | except openai.APIStatusError as e: |
| | print("Another non-200-range status code was received") |
| | print(e.status_code) |
| | print(e.response) |
| |
|
| | print("Max retries exceeded. Returning empty response.") |
| | return None |
| |
|
| | |
| | def get_message_google(self, prompt, instructions, max_retries=6): |
| |
|
| | client = genai.Client(api_key=self.get_credential("Google_API")) |
| |
|
| | for attempt in range(max_retries): |
| | try: |
| | response = client.models.generate_content( |
| | model=self.Core.model, |
| | contents=prompt, |
| | config=types.GenerateContentConfig( |
| | thinking_config=types.ThinkingConfig(thinking_budget=0), |
| | system_instruction=instructions, |
| | temperature=self.Core.temperature, |
| | response_mime_type="application/json" |
| | )) |
| |
|
| | |
| | tokens = { |
| | 'prompt_tokens': response.usage_metadata.prompt_token_count, |
| | 'completion_tokens': response.usage_metadata.candidates_token_count, |
| | 'total_tokens': response.usage_metadata.total_token_count |
| | } |
| |
|
| | |
| | self.Core.total_tokens['prompt_tokens'] += tokens['prompt_tokens'] |
| | self.Core.total_tokens['completion_tokens'] += tokens['completion_tokens'] |
| | self.Core.temp_token_counter += tokens['total_tokens'] |
| |
|
| | output = self.preprocess_and_parse_json(response.text) |
| |
|
| | if 'message' not in output or 'header' not in output: |
| | print(f"'message' or 'header' is missing in response on attempt {attempt + 1}. Retrying...") |
| | continue |
| |
|
| | else: |
| | if len(output["header"].strip()) > self.Core.config_file["header_limit"] or len( |
| | output["message"].strip()) > self.Core.config_file["message_limit"]: |
| | print( |
| | f"'header' or 'message' is more than specified characters in response on attempt {attempt + 1}. Retrying...") |
| | continue |
| | return output |
| |
|
| | except json.JSONDecodeError: |
| | print(f"Invalid JSON from LLM on attempt {attempt + 1}. Retrying...") |
| | except Exception as e: |
| | print(f"Error in attempt {attempt}: {e}") |
| |
|
| | print("Max retries exceeded. Returning empty response.") |
| | return None |
| |
|
| | |
| |
|
| | def get_message_openai(self, prompt, instructions, max_retries=5): |
| | """ |
| | sending the prompt to openai LLM and get back the response |
| | """ |
| |
|
| | openai.api_key = self.Core.api_key |
| | client = OpenAI(api_key=self.Core.api_key) |
| |
|
| | for attempt in range(max_retries): |
| | try: |
| | if self.Core.reasoning_model: |
| | response = client.chat.completions.create( |
| | model=self.Core.model, |
| | response_format={"type": "json_object"}, |
| | messages=[ |
| | {"role": "system", "content": instructions}, |
| | {"role": "user", "content": prompt} |
| | ], |
| | reasoning_effort="minimal", |
| | n=1, |
| | ) |
| |
|
| | else: |
| | response = client.chat.completions.create( |
| | model=self.Core.model, |
| | response_format={"type": "json_object"}, |
| | messages=[ |
| | {"role": "system", "content": instructions}, |
| | {"role": "user", "content": prompt} |
| | ], |
| | n=1, |
| | temperature=self.Core.temperature |
| | ) |
| |
|
| | tokens = { |
| | 'prompt_tokens': response.usage.prompt_tokens, |
| | 'completion_tokens': response.usage.completion_tokens, |
| | 'total_tokens': response.usage.total_tokens |
| | } |
| |
|
| | |
| | self.Core.total_tokens['prompt_tokens'] += tokens['prompt_tokens'] |
| | self.Core.total_tokens['completion_tokens'] += tokens['completion_tokens'] |
| | self.Core.temp_token_counter += tokens['total_tokens'] |
| |
|
| | try: |
| | content = response.choices[0].message.content |
| |
|
| | |
| |
|
| | output = json.loads(content) |
| |
|
| | if 'message' not in output or 'header' not in output: |
| | print(f"'message' or 'header' is missing in response on attempt {attempt + 1}. Retrying...") |
| | continue |
| |
|
| | else: |
| | if len(output["header"].strip()) > self.Core.config_file["header_limit"] or len( |
| | output["message"].strip()) > self.Core.config_file["message_limit"]: |
| | print( |
| | f"'header' or 'message' is more than specified characters in response on attempt {attempt + 1}. Retrying...") |
| | continue |
| |
|
| | return output |
| |
|
| | except json.JSONDecodeError: |
| | print(f"Invalid JSON from LLM on attempt {attempt + 1}. Retrying...") |
| |
|
| | except openai.APIConnectionError as e: |
| | print("The server could not be reached") |
| | print(e.__cause__) |
| | except openai.RateLimitError as e: |
| | print("A 429 status code was received; we should back off a bit.") |
| | except openai.APIStatusError as e: |
| | print("Another non-200-range status code was received") |
| | print(e.status_code) |
| | print(e.response) |
| |
|
| | print("Max retries exceeded. Returning empty response.") |
| | return None |
| |
|
| | |
| |
|
    def get_message_ollama(self, prompt, instructions, max_retries=10):
        """
        Send the prompt to the LLM and get back the response.
        Includes handling for GPU memory issues by clearing cache and waiting before retry.

        NOTE(review): connect_to_llm() as visible in this file never configures
        an ollama client, and get_response() never dispatches here — confirm
        this method is still reachable before relying on it.
        """
        # The generate API takes a single prompt, so system instructions are prepended.
        prompt = instructions + "\n \n" + prompt
        for attempt in range(max_retries):
            try:
                response = self.client.generate(model=self.model, prompt=prompt)
            except Exception as e:
                print(f"Error on attempt {attempt + 1}: {e}.")
                try:
                    # Best-effort cleanup: generation failures are treated as
                    # possible CUDA OOM, so free cached GPU memory before retrying.
                    torch.cuda.empty_cache()
                    print("Cleared GPU cache.")
                except Exception as cache_err:
                    print("Failed to clear GPU cache:", cache_err)
                time.sleep(2)
                continue

            try:
                # NOTE(review): this token dict is never read nor added to
                # Core.total_tokens — appears to be a placeholder, unlike the
                # other provider methods which do accumulate usage.
                tokens = {
                    'prompt_tokens': 0,
                    'completion_tokens': 0,
                    'total_tokens': 0
                }

                try:
                    # Parser returns None when no JSON could be recovered.
                    output = self.preprocess_and_parse_json(response.response)
                    if output is None:
                        continue

                    if 'message' not in output or 'header' not in output:
                        print(f"'message' or 'header' is missing in response on attempt {attempt + 1}. Retrying...")
                        continue

                    else:
                        # Enforce configured length limits before accepting the reply.
                        if len(output["header"].strip()) > self.Core.config_file["header_limit"] or len(
                                output["message"].strip()) > self.Core.config_file["message_limit"]:
                            print(
                                f"'header' or 'message' is more than specified characters in response on attempt {attempt + 1}. Retrying...")
                            continue
                        else:
                            return output

                except json.JSONDecodeError:
                    print(f"Invalid JSON from LLM on attempt {attempt + 1}. Retrying...")
            except Exception as parse_error:
                # Catch-all for unexpected shapes in the response object.
                print("Error processing output:", parse_error)

        print("Max retries exceeded. Returning empty response.")
        return None
| |
|
| | def get_message_claude(self, prompt, instructions, max_retries=6): |
| | """ |
| | send prompt to claude LLM and get back the response |
| | :param prompt: |
| | :param instructions: |
| | :return: |
| | """ |
| |
|
| |
|
| | for attempt in range(max_retries): |
| | try: |
| |
|
| | message = self.client.messages.create( |
| | model=self.model, |
| | max_tokens=4096, |
| | system = instructions, |
| | messages=[ |
| | {"role": "user", "content": prompt + "\nHere is the JSON requested:\n"} |
| | ], |
| | temperature=self.Core.temperature |
| | ) |
| | |
| | response = message.content[0].text |
| |
|
| | tokens = { |
| | 'prompt_tokens': message.usage.input_tokens, |
| | 'completion_tokens': message.usage.output_tokens, |
| | 'total_tokens': message.usage.output_tokens + message.usage.input_tokens |
| | } |
| |
|
| | self.Core.total_tokens['prompt_tokens'] += tokens['prompt_tokens'] |
| | self.Core.total_tokens['completion_tokens'] += tokens['completion_tokens'] |
| | self.Core.temp_token_counter += tokens['total_tokens'] |
| |
|
| | try: |
| | output = self.preprocess_and_parse_json_claude(response) |
| | if output is None: |
| | continue |
| |
|
| | if 'message' not in output or 'header' not in output: |
| | print(f"'message' or 'header' is missing in response on attempt {attempt + 1}. Retrying...") |
| | continue |
| |
|
| | else: |
| | if len(output["header"].strip()) > self.Core.config_file["header_limit"] or len( |
| | output["message"].strip()) > self.Core.config_file["message_limit"]: |
| | print( |
| | f"'header' or 'message' is more than specified characters in response on attempt {attempt + 1}. Retrying...") |
| | continue |
| | else: |
| | return output |
| |
|
| | except json.JSONDecodeError: |
| | print(f"Invalid JSON from LLM on attempt {attempt + 1}. Retrying...") |
| | except Exception as parse_error: |
| | print("Error processing output:", parse_error) |
| |
|
| | print("Max retries exceeded. Returning empty response.") |
| | return None |
| |
|
| | |
| |
|
| | def preprocess_and_parse_json(self, response: str): |
| | """ |
| | Remove <think> blocks, extract JSON (from ```json fences or first {...} block), |
| | and parse. Includes a repair pass to handle common LLM issues like trailing commas. |
| | """ |
| |
|
| | def extract_json(text: str) -> str: |
| | |
| | text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip() |
| |
|
| | |
| | fence = re.search(r'```(?:json)?(.*?)```', text, flags=re.DOTALL | re.IGNORECASE) |
| | if fence: |
| | return fence.group(1).strip() |
| |
|
| | |
| | brace = re.search(r'\{.*\}', text, flags=re.DOTALL) |
| | return brace.group(0).strip() if brace else text.strip() |
| |
|
| | def normalize_quotes(text: str) -> str: |
| | return (text |
| | .replace('\ufeff', '') |
| | .replace('“', '"').replace('”', '"') |
| | .replace('‘', "'").replace('’', "'")) |
| |
|
| | def strip_comments(text: str) -> str: |
| | |
| | text = re.sub(r'//.*?$', '', text, flags=re.MULTILINE) |
| | text = re.sub(r'/\*.*?\*/', '', text, flags=re.DOTALL) |
| | return text |
| |
|
| | def remove_trailing_commas(text: str) -> str: |
| | |
| | return re.sub(r',(\s*[}\]])', r'\1', text) |
| |
|
| | raw = extract_json(response) |
| | raw = normalize_quotes(raw) |
| |
|
| | try: |
| | return json.loads(raw) |
| | except json.JSONDecodeError: |
| | |
| | repaired = strip_comments(raw) |
| | repaired = remove_trailing_commas(repaired) |
| | repaired = repaired.strip() |
| |
|
| | try: |
| | return json.loads(repaired) |
| | except json.JSONDecodeError as e: |
| | print(f"Failed to parse JSON: {e}") |
| | |
| | return None |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | def preprocess_and_parse_json_claude(self, response: str): |
| | """ |
| | Process Claude response and extract JSON content safely |
| | """ |
| | try: |
| | json_start = response.index("{") |
| | json_end = response.rfind("}") |
| | json_string = response[json_start:json_end + 1] |
| |
|
| | parsed_response = json.loads(json_string) |
| |
|
| | if not isinstance(parsed_response, dict): |
| | raise ValueError(f"Parsed response is not a dict: {parsed_response}") |
| |
|
| | return parsed_response |
| |
|
| | except ValueError as ve: |
| | raise ValueError(f"Could not extract JSON from Claude response: {ve}\nOriginal response: {response}") |
| | except json.JSONDecodeError as je: |
| | raise ValueError(f"Failed to parse JSON from string: {json_string}\nError: {je}") |
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|