|
|
|
|
|
import os
|
|
|
from typing import Union
|
|
|
import src.translation_agent.utils as utils
|
|
|
|
|
|
from llama_index.llms.groq import Groq
|
|
|
from llama_index.llms.cohere import Cohere
|
|
|
from llama_index.llms.openai import OpenAI
|
|
|
from llama_index.llms.together import TogetherLLM
|
|
|
from llama_index.llms.ollama import Ollama
|
|
|
from llama_index.llms.huggingface_api import HuggingFaceInferenceAPI
|
|
|
|
|
|
from llama_index.core import Settings
|
|
|
from llama_index.core.llms import ChatMessage
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def model_load(
|
|
|
endpoint: str,
|
|
|
model: str,
|
|
|
api_key: str = None,
|
|
|
context_window: int = 4096,
|
|
|
num_output: int = 512,
|
|
|
):
|
|
|
if endpoint == "Groq":
|
|
|
llm = Groq(
|
|
|
model=model,
|
|
|
api_key=api_key if api_key else os.getenv("GROQ_API_KEY"),
|
|
|
)
|
|
|
elif endpoint == "Cohere":
|
|
|
llm = Cohere(
|
|
|
model=model,
|
|
|
api_key=api_key if api_key else os.getenv("COHERE_API_KEY"),
|
|
|
)
|
|
|
elif endpoint == "OpenAI":
|
|
|
llm = OpenAI(
|
|
|
model=model,
|
|
|
api_key=api_key if api_key else os.getenv("OPENAI_API_KEY"),
|
|
|
)
|
|
|
elif endpoint == "TogetherAI":
|
|
|
llm = TogetherLLM(
|
|
|
model=model,
|
|
|
api_key=api_key if api_key else os.getenv("TOGETHER_API_KEY"),
|
|
|
)
|
|
|
elif endpoint == "Ollama":
|
|
|
llm = Ollama(
|
|
|
model=model,
|
|
|
request_timeout=120.0)
|
|
|
elif endpoint == "Huggingface":
|
|
|
llm = HuggingFaceInferenceAPI(
|
|
|
model_name=model,
|
|
|
token=api_key if api_key else os.getenv("HF_TOKEN"),
|
|
|
task="text-generation",
|
|
|
)
|
|
|
Settings.llm = llm
|
|
|
|
|
|
Settings.context_window = context_window
|
|
|
|
|
|
|
|
|
Settings.num_output = num_output
|
|
|
|
|
|
|
|
|
def get_completion(
|
|
|
prompt: str,
|
|
|
system_message: str = "You are a helpful assistant.",
|
|
|
temperature: float = 0.3,
|
|
|
json_mode: bool = False,
|
|
|
) -> Union[str, dict]:
|
|
|
"""
|
|
|
Generate a completion using the OpenAI API.
|
|
|
|
|
|
Args:
|
|
|
prompt (str): The user's prompt or query.
|
|
|
system_message (str, optional): The system message to set the context for the assistant.
|
|
|
Defaults to "You are a helpful assistant.".
|
|
|
temperature (float, optional): The sampling temperature for controlling the randomness of the generated text.
|
|
|
Defaults to 0.3.
|
|
|
json_mode (bool, optional): Whether to return the response in JSON format.
|
|
|
Defaults to False.
|
|
|
|
|
|
Returns:
|
|
|
Union[str, dict]: The generated completion.
|
|
|
If json_mode is True, returns the complete API response as a dictionary.
|
|
|
If json_mode is False, returns the generated text as a string.
|
|
|
"""
|
|
|
llm = Settings.llm
|
|
|
if llm.class_name() == "HuggingFaceInferenceAPI":
|
|
|
llm.system_prompt = system_message
|
|
|
messages = [
|
|
|
ChatMessage(
|
|
|
role="user", content=prompt),
|
|
|
]
|
|
|
response = llm.chat(
|
|
|
messages=messages,
|
|
|
temperature=temperature,
|
|
|
)
|
|
|
return response.message.content
|
|
|
else:
|
|
|
messages = [
|
|
|
ChatMessage(
|
|
|
role="system", content=system_message),
|
|
|
ChatMessage(
|
|
|
role="user", content=prompt),
|
|
|
]
|
|
|
|
|
|
if json_mode:
|
|
|
response = llm.chat(
|
|
|
temperature=temperature,
|
|
|
response_format={"type": "json_object"},
|
|
|
messages=messages,
|
|
|
)
|
|
|
return response.message.content
|
|
|
else:
|
|
|
response = llm.chat(
|
|
|
temperature=temperature,
|
|
|
messages=messages,
|
|
|
)
|
|
|
return response.message.content
|
|
|
|
|
|
utils.get_completion = get_completion
|
|
|
|
|
|
one_chunk_initial_translation = utils.one_chunk_initial_translation
|
|
|
one_chunk_reflect_on_translation = utils.one_chunk_reflect_on_translation
|
|
|
one_chunk_improve_translation = utils.one_chunk_improve_translation
|
|
|
one_chunk_translate_text = utils.one_chunk_translate_text
|
|
|
num_tokens_in_string = utils.num_tokens_in_string
|
|
|
multichunk_initial_translation = utils.multichunk_initial_translation
|
|
|
multichunk_reflect_on_translation = utils.multichunk_reflect_on_translation
|
|
|
multichunk_improve_translation = utils.multichunk_improve_translation
|
|
|
multichunk_translation = utils.multichunk_translation
|
|
|
calculate_chunk_size =utils.calculate_chunk_size |