# utils.py
import re
import subprocess
import os
from typing import Optional, Any, Type, TypedDict, List
from pydantic import BaseModel, Field
from langchain.chat_models import init_chat_model
from langchain_community.vectorstores import FAISS
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain_aws import ChatBedrock, ChatBedrockConverse
from langchain_anthropic import ChatAnthropic
from pathlib import Path
import tracking_aws
import requests
import time
import random
from botocore.exceptions import ClientError
import shutil
from config import Config
from langchain_ollama import ChatOllama
# Global dictionary to store loaded FAISS databases
DATABASE_DIR = f"{Path(__file__).resolve().parent.parent}/database/faiss"
FAISS_DB_CACHE = {
"openfoam_allrun_scripts": FAISS.load_local(f"{DATABASE_DIR}/openfoam_allrun_scripts", OpenAIEmbeddings(model="text-embedding-3-small"), allow_dangerous_deserialization=True),
"openfoam_tutorials_structure": FAISS.load_local(f"{DATABASE_DIR}/openfoam_tutorials_structure", OpenAIEmbeddings(model="text-embedding-3-small"), allow_dangerous_deserialization=True),
"openfoam_tutorials_details": FAISS.load_local(f"{DATABASE_DIR}/openfoam_tutorials_details", OpenAIEmbeddings(model="text-embedding-3-small"), allow_dangerous_deserialization=True),
"openfoam_command_help": FAISS.load_local(f"{DATABASE_DIR}/openfoam_command_help", OpenAIEmbeddings(model="text-embedding-3-small"), allow_dangerous_deserialization=True)
}
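# Illustrative lookup sketch (not executed at import time): each cached FAISS store supports
# similarity search over its indexed documents, e.g.
#     docs = FAISS_DB_CACHE["openfoam_command_help"].similarity_search("blockMesh usage", k=3)
# where `docs` is the list of the k most similar documents for the query string.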
class FoamfilePydantic(BaseModel):
file_name: str = Field(description="Name of the OpenFOAM input file")
folder_name: str = Field(description="Folder where the foamfile should be stored")
content: str = Field(description="Content of the OpenFOAM file, written in OpenFOAM dictionary format")
class FoamPydantic(BaseModel):
list_foamfile: List[FoamfilePydantic] = Field(description="List of OpenFOAM configuration files")
class ResponseWithThinkPydantic(BaseModel):
think: str = Field(description="Thought process of the LLM")
response: str = Field(description="Response of the LLM")
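# Illustrative usage sketch (assumes an LLMService instance built from the project Config, as
# defined below; the prompt text here is made up for illustration):
#     llm_service = LLMService(Config())
#     case_files = llm_service.invoke(
#         user_prompt="Write the OpenFOAM dictionaries for a lid-driven cavity case",
#         system_prompt="You are an OpenFOAM case-setup assistant",
#         pydantic_obj=FoamPydantic,
#     )
#     # `case_files.list_foamfile` is a list of FoamfilePydantic entries (file_name, folder_name, content)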
class LLMService:
def __init__(self, config: object):
self.model_version = getattr(config, "model_version", "gpt-4o")
self.temperature = getattr(config, "temperature", 0)
self.model_provider = getattr(config, "model_provider", "openai")
# Initialize statistics
self.total_calls = 0
self.total_prompt_tokens = 0
self.total_completion_tokens = 0
self.total_tokens = 0
self.failed_calls = 0
self.retry_count = 0
# Initialize the LLM
if self.model_provider.lower() == "bedrock":
bedrock_runtime = tracking_aws.new_default_client()
self.llm = ChatBedrockConverse(
client=bedrock_runtime,
model_id=self.model_version,
temperature=self.temperature,
max_tokens=8192
)
elif self.model_provider.lower() == "anthropic":
self.llm = ChatAnthropic(
model=self.model_version,
temperature=self.temperature
)
elif self.model_provider.lower() == "openai":
self.llm = init_chat_model(
self.model_version,
model_provider=self.model_provider,
temperature=self.temperature
)
elif self.model_provider.lower() == "ollama":
try:
                requests.get("http://localhost:11434/api/version", timeout=2)
                # If the request succeeds, the Ollama service is already running
except requests.exceptions.RequestException:
print("Ollama is not running, starting it...")
subprocess.Popen(["ollama", "serve"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
                # Wait for the service to start
                time.sleep(5)  # Give it 5 seconds to initialize
self.llm = ChatOllama(
model=self.model_version,
temperature=self.temperature,
num_predict=-1,
num_ctx=131072,
base_url="http://localhost:11434"
)
else:
raise ValueError(f"{self.model_provider} is not a supported model_provider")
def invoke(self,
user_prompt: str,
system_prompt: Optional[str] = None,
pydantic_obj: Optional[Type[BaseModel]] = None,
max_retries: int = 10) -> Any:
"""
Invoke the LLM with the given prompts and return the response.
Args:
user_prompt: The user's prompt
system_prompt: Optional system prompt
pydantic_obj: Optional Pydantic model for structured output
max_retries: Maximum number of retries for throttling errors
Returns:
            The LLM response; token usage statistics are accumulated on the service instance
"""
self.total_calls += 1
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": user_prompt})
# Calculate prompt tokens
prompt_tokens = 0
for message in messages:
prompt_tokens += self.llm.get_num_tokens(message["content"])
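        # Note: get_num_tokens gives a tokenizer-based estimate; providers may report
        # slightly different counts in their own usage metadata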
retry_count = 0
while True:
try:
if pydantic_obj:
structured_llm = self.llm.with_structured_output(pydantic_obj)
response = structured_llm.invoke(messages)
else:
if self.model_version.startswith("deepseek"):
structured_llm = self.llm.with_structured_output(ResponseWithThinkPydantic)
response = structured_llm.invoke(messages)
                        # Extract the response text, dropping the think block
response = response.response
else:
response = self.llm.invoke(messages)
response = response.content
# Calculate completion tokens
response_content = str(response)
completion_tokens = self.llm.get_num_tokens(response_content)
total_tokens = prompt_tokens + completion_tokens
# Update statistics
self.total_prompt_tokens += prompt_tokens
self.total_completion_tokens += completion_tokens
self.total_tokens += total_tokens
return response
except ClientError as e:
                if e.response['Error']['Code'] in ('Throttling', 'ThrottlingException', 'TooManyRequestsException'):
retry_count += 1
self.retry_count += 1
if retry_count > max_retries:
self.failed_calls += 1
raise Exception(f"Maximum retries ({max_retries}) exceeded: {str(e)}")
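                    # Exponential backoff: the delay doubles per attempt and is capped,
                    # e.g. attempts 1..7 -> 1s, 2s, 4s, 8s, 16s, 32s, 60s (plus up to 10% jitter)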
base_delay = 1.0
max_delay = 60.0
delay = min(max_delay, base_delay * (2 ** (retry_count - 1)))
jitter = random.uniform(0, 0.1 * delay)
sleep_time = delay + jitter
print(f"ThrottlingException occurred: {str(e)}. Retrying in {sleep_time:.2f} seconds (attempt {retry_count}/{max_retries})")
time.sleep(sleep_time)
else:
self.failed_calls += 1
raise e
except Exception as e:
self.failed_calls += 1
raise e
def get_statistics(self) -> dict:
"""
Get the current statistics of the LLM service.
Returns:
Dictionary containing various statistics
"""
return {
"total_calls": self.total_calls,
"failed_calls": self.failed_calls,
"retry_count": self.retry_count,
"total_prompt_tokens": self.total_prompt_tokens,
"total_completion_tokens": self.total_completion_tokens,
"total_tokens": self.total_tokens,
"average_prompt_tokens": self.total_prompt_tokens / self.total_calls if self.total_calls > 0 else 0,
"average_completion_tokens": self.total_completion_tokens / self.total_calls if self.total_calls > 0 else 0,
"average_tokens": self.total_tokens / self.total_calls if self.total_calls > 0 else 0
}
def print_statistics(self) -> None:
"""
Print the current statistics of the LLM service.
"""
stats = self.get_statistics()
print("\n