cole / src /language_model /open_ai_api_lm_wrapper.py
davebulaval's picture
v1
8fa3acc
# pylint: disable=inconsistent-return-statements
import ast
import json
import logging
import time
from abc import ABC, abstractmethod
from typing import Dict, List

from src import NA_VALUE
from src.language_model.init_function_calling import init_function_calling
class OpenAIAPILMWrapper(ABC):
    """Base wrapper around a chat-completion API (OpenAI-style or Anthropic).

    Subclasses implement :meth:`_inner_generate_fn`, which sends a prompt to the
    provider and returns the raw completion object. This base class handles
    retries, optional function-calling setup and the reduction of a raw
    completion to a single string prediction.

    Args:
        model_name: Name of the model. A name containing "claude"
            (case-insensitive) is treated as an Anthropic model. This changes
            how function-calling parameters are built and how completions are
            parsed.
        extra_params: Extra request parameters. Updated in place by
            :meth:`init_function_calling`.
        use_function_calling: Whether :meth:`init_function_calling` adds
            function-calling parameters to ``extra_params``.
        max_retries: Number of retries allowed after a failed call, so at most
            ``max_retries + 1`` attempts are made.
        retry_delay_seconds: Seconds to wait before each retry.
    """

    # Labels used to reduce a free-form answer to a single class id.
    # Two-digit labels come first: "10" and "11" contain "0"/"1", so checking
    # the single digits first would make them unreachable.
    _DIGIT_LABELS = ("10", "11", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9")

    def __init__(
        self,
        model_name: str,
        extra_params: Dict,
        use_function_calling: bool = True,
        max_retries: int = 10,
        retry_delay_seconds: float = 5,
    ):
        self.model_name = model_name
        self._extra_params = extra_params
        self.max_retries = max_retries
        self.retry_delay_seconds = retry_delay_seconds
        self.use_function_calling = use_function_calling
        # NOTE(review): not read inside this class; presumably used by callers
        # when invoking ``init_function_calling`` — confirm.
        self.tool_choices = "required"
        # Counters reported by ``print_none``.
        self._none_prediction_counter = 0
        self._max_retries_counter = 0

    @abstractmethod
    def _inner_generate_fn(self, prompt: List) -> Dict:
        """Send ``prompt`` to the provider and return the raw completion.

        Any exception raised here is treated as a failed attempt and retried
        by :meth:`language_model_calling`.
        """

    def _is_anthropic_model(self) -> bool:
        """Return True when ``model_name`` designates an Anthropic (Claude) model."""
        return "claude" in self.model_name.lower()

    def init_function_calling(self, labels: List[str], tool_choices: str) -> None:
        """Merge function-calling parameters for ``labels`` into ``extra_params``.

        Does nothing when ``use_function_calling`` is False. Otherwise the
        parameters built by the imported ``init_function_calling`` helper are
        merged into the ``extra_params`` dict given at construction. That dict
        is mutated.
        """
        if not self.use_function_calling:
            return
        self._extra_params.update(
            init_function_calling(
                labels=labels,
                tool_choices=tool_choices,
                # False for Anthropic models, True for the rest.
                open_ai_api_call=not self._is_anthropic_model(),
            )
        )

    def predict(self, text: str) -> Dict:
        """Query the model with ``text`` and return ``{"prediction": <str>}``."""
        prompt = self.format_prompt(text)
        generated_completion = self.language_model_calling(prompt=prompt)
        final_prediction = self.extract_final_prediction(generated_completion)
        return {"prediction": final_prediction}

    @staticmethod
    def format_prompt(text: str) -> List:
        """Wrap ``text`` as a single user message in chat-completion format."""
        return [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": text},
                ],
            }
        ]

    def language_model_calling(self, prompt: List):
        """Call the model with retries.

        Returns the completion, or None if every attempt failed.
        """
        return self._try_again(prompt=prompt)

    def _try_again(self, prompt: List, retries: int = 0):
        """Call ``_inner_generate_fn`` until it succeeds or the retry budget is spent.

        Args:
            prompt: Messages forwarded to ``_inner_generate_fn``.
            retries: Number of retries already consumed.

        Returns:
            The first successful completion. Returns None once more than
            ``max_retries`` retries have failed; in that case
            ``_max_retries_counter`` is incremented.
        """
        while retries <= self.max_retries:
            try:
                return self._inner_generate_fn(prompt=prompt)
            except Exception as error:  # pylint: disable=broad-except
                # Provider SDKs raise many unrelated exception types for
                # transient failures, so any Exception is retried.
                # KeyboardInterrupt/SystemExit still propagate.
                logging.warning(
                    "Language model call failed on attempt %d of %d: %r",
                    retries + 1,
                    self.max_retries + 1,
                    error,
                )
                retries += 1
                if retries <= self.max_retries:
                    # Only wait when another attempt will follow.
                    time.sleep(self.retry_delay_seconds)
        logging_message = f"Max retries exceeded: {retries}."
        logging.warning(logging_message)
        self._max_retries_counter += 1
        return None

    def extract_final_prediction(self, generated_completion) -> str:
        """Reduce a raw completion to a single string prediction.

        The raw answer is read from the Anthropic tool-use block, or from the
        OpenAI tool call or message content. It is then normalised:

        * Missing answer: returns ``NA_VALUE`` as a string and increments
          ``_none_prediction_counter``.
        * Answer containing "La réponse est" or ":": returns the text after
          the last ":" with all spaces removed.
        * Answer containing one of ``_DIGIT_LABELS``: returns the first such
          label, checked in ``_DIGIT_LABELS`` order.
        * Otherwise: returns the answer unchanged.
        """
        if self._is_anthropic_model():
            raw_prediction = self._extract_anthropic_answer(generated_completion)
        else:
            raw_prediction = self._extract_openai_answer(generated_completion)

        if raw_prediction is None:
            # NA_VALUE is returned as a string so it can later be cast to int
            # (infer case) or kept as text (generate case). In both cases it
            # never yields a better score.
            self._none_prediction_counter += 1
            return f"{NA_VALUE}"

        # Tool arguments may hold non-string values (e.g. an int category).
        final_prediction = str(raw_prediction)
        if "La réponse est" in final_prediction or ":" in final_prediction:
            # The LLM repeated the premise of the query; keep only the answer.
            return final_prediction.split(":")[-1].strip().replace(" ", "")
        # The answer is surrounded by other text but should be a single label.
        return next(
            (label for label in self._DIGIT_LABELS if label in final_prediction),
            final_prediction,
        )

    @staticmethod
    def _extract_anthropic_answer(generated_completion):
        """Return the tool-use "category" of an Anthropic completion, or None."""
        if generated_completion is None or not generated_completion.content:
            return None
        # Blocks that are not tool-use blocks have no ``input`` attribute.
        tool_input = getattr(generated_completion.content[0], "input", None)
        if tool_input is None:
            return None
        try:
            return tool_input.get("category")
        except AttributeError:
            # The tool input is not a proper dictionary; keep the raw value.
            return tool_input

    @classmethod
    def _extract_openai_answer(cls, generated_completion):
        """Return the answer carried by an OpenAI-style completion, or None."""
        # Completions are sometimes incomplete, so each component is checked.
        if generated_completion is None or not generated_completion.choices:
            return None
        message = generated_completion.choices[0].message
        if message is None:
            return None
        if not message.tool_calls:
            if message.content is None:
                return None
            # No tool call, but the answer may be in the raw message content.
            return message.content.strip().replace(")", "").strip()
        return cls._parse_tool_arguments(message.tool_calls[0].function.arguments)

    @staticmethod
    def _parse_tool_arguments(arguments):
        """Return the "category" field of tool-call ``arguments``.

        ``arguments`` is parsed as JSON first, then as a Python literal, since
        some models emit single-quoted dicts. Both parsers are safe on
        untrusted model output, unlike ``eval``. If the arguments do not parse
        to a dictionary, they are returned unchanged.
        """
        for parse in (json.loads, ast.literal_eval):
            try:
                parsed = parse(arguments)
            except (ValueError, TypeError, SyntaxError, RecursionError, MemoryError):
                continue
            if isinstance(parsed, dict):
                return parsed.get("category")
            # Parsed, but not a dictionary: fall back to the raw arguments.
            break
        return arguments

    def print_none(self) -> None:
        """Log how many predictions were missing and how many calls exhausted retries."""
        if self._none_prediction_counter > 0:
            logging_message = f"Number of None: {self._none_prediction_counter}."
            logging.warning(logging_message)
        if self._max_retries_counter > 0:
            logging_message = f"Number of max retries exceeded occurrence: {self._max_retries_counter}."
            logging.warning(logging_message)