Spaces:
Running
Running
File size: 6,742 Bytes
8fa3acc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
# pylint: disable=inconsistent-return-statements
import ast
import json
import logging
import time
from abc import ABC, abstractmethod
from typing import List, Dict

from src import NA_VALUE
from src.language_model.init_function_calling import init_function_calling
class OpenAIAPILMWrapper(ABC):
    """Abstract wrapper around a chat-completion language model.

    Subclasses provide ``_inner_generate_fn``, which performs the actual API
    call. This base class builds the prompt, retries failed calls, and turns
    the raw completion into a label string.

    Models whose name contains ``"claude"`` are read with the Anthropic
    response layout (``completion.content[0].input``); every other model is
    read with the OpenAI layout (``completion.choices[0].message``).
    """

    # Labels searched for inside a free-form answer, in priority order. The
    # two-digit labels come first: "0" and "1" are substrings of "10"/"11"
    # and would otherwise shadow them.
    _NUMERIC_LABELS = ("10", "11", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9")

    def __init__(
        self,
        model_name: str,
        extra_params: Dict,
        use_function_calling: bool = True,
        max_retries: int = 10,
        retry_delay: float = 5.0,
    ):
        """Store the model configuration and reset the failure counters.

        Args:
            model_name: Name of the model; also used to pick the response layout.
            extra_params: Extra parameters kept for the API call. This dict is
                updated in place by ``init_function_calling``.
            use_function_calling: Whether ``init_function_calling`` should add
                the function-calling parameters to ``extra_params``.
            max_retries: Number of retries allowed after the first failed call.
            retry_delay: Seconds to sleep after each failed call.
        """
        self.model_name = model_name
        self._extra_params = extra_params
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self.use_function_calling = use_function_calling
        # NOTE(review): not read anywhere in this class; presumably consumed
        # by subclasses or callers.
        self.tool_choices = "required"

        self._none_prediction_counter = 0
        self._max_retries_counter = 0

    @abstractmethod
    def _inner_generate_fn(self, prompt: List) -> Dict:
        """Call the underlying model with ``prompt`` and return its completion."""

    def init_function_calling(self, labels: List[str], tool_choices: str) -> None:
        """Add the function-calling parameters to the stored extra parameters.

        Does nothing when ``use_function_calling`` is False.

        Args:
            labels: Labels forwarded to the module-level ``init_function_calling``.
            tool_choices: Tool-choice setting forwarded to the same helper.
        """
        if self.use_function_calling:
            open_ai_api_call = (
                "claude" not in self.model_name.lower()
            )  # False for Anthropic model, True for the rest

            # Inside a method the bare name resolves to the imported
            # module-level function, not to this method.
            self._extra_params.update(
                init_function_calling(
                    labels=labels,
                    tool_choices=tool_choices,
                    open_ai_api_call=open_ai_api_call,
                )
            )

    def predict(self, text: str) -> Dict:
        """Run the model on ``text`` and return ``{"prediction": <label string>}``."""
        prompt = self.format_prompt(text)
        generated_completion = self.language_model_calling(prompt=prompt)
        final_prediction = self.extract_final_prediction(generated_completion)
        return {"prediction": final_prediction}

    @staticmethod
    def format_prompt(text: str) -> List:
        """Wrap ``text`` into a single-message, single-text-part user prompt."""
        return [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": text},
                ],
            }
        ]

    def language_model_calling(self, prompt: List):
        """Call the model with retries; return the completion or None on failure."""
        return self._try_again(prompt=prompt)

    def _try_again(self, prompt: List, retries: int = 0):
        """Call ``_inner_generate_fn`` until it succeeds or retries run out.

        Args:
            prompt: Prompt forwarded to ``_inner_generate_fn``.
            retries: Number of attempts already consumed.

        Returns:
            The completion of the first successful attempt, or None once more
            than ``max_retries`` retries have failed (in which case
            ``_max_retries_counter`` is incremented).
        """
        attempt = retries
        while attempt <= self.max_retries:
            try:
                # Returning from inside the loop is the fix: the previous
                # recursive version discarded the result of every retry.
                return self._inner_generate_fn(prompt=prompt)
            except Exception as error:  # best-effort: any API error triggers a retry
                logging.warning("Generation attempt %d failed: %s", attempt + 1, error)
                time.sleep(self.retry_delay)
                attempt += 1

        logging_message = f"Max retries exceeded: {attempt}."
        logging.warning(logging_message)
        self._max_retries_counter += 1
        return None

    def extract_final_prediction(self, generated_completion) -> str:
        """Extract the predicted label from a raw completion.

        Args:
            generated_completion: Completion object returned by the model, or
                None when the call failed.

        Returns:
            The label as a string. When no prediction can be extracted, the
            string form of ``NA_VALUE`` is returned and
            ``_none_prediction_counter`` is incremented.
        """
        if "claude" in self.model_name.lower():
            raw_prediction = self._extract_anthropic_prediction(generated_completion)
        else:
            raw_prediction = self._extract_openai_prediction(generated_completion)
        return self._normalize_prediction(raw_prediction)

    @staticmethod
    def _extract_anthropic_prediction(generated_completion):
        """Read the ``category`` from the first content block of an Anthropic completion."""
        if generated_completion is None or not generated_completion.content:
            return None

        # The first block may have no ``input`` attribute; treat that like a
        # missing prediction instead of raising.
        tool_input = getattr(generated_completion.content[0], "input", None)
        if tool_input is None:
            return None
        try:
            return tool_input.get("category")
        except AttributeError:
            # Case where the prediction is not a proper dictionary.
            return tool_input

    @classmethod
    def _extract_openai_prediction(cls, generated_completion):
        """Read the ``category`` from the first choice of an OpenAI-style completion."""
        # Sometimes the completion is incomplete, thus we validate that components are there.
        if generated_completion is None or not generated_completion.choices:
            return None
        message = generated_completion.choices[0].message
        if message is None:
            return None

        if not message.tool_calls:
            if message.content is None:
                return None
            # No tools call, but potentially a response in the raw message content.
            return message.content.strip().replace(")", "").strip()

        return cls._parse_tool_arguments(message.tool_calls[0].function.arguments)

    @staticmethod
    def _parse_tool_arguments(arguments):
        """Return the ``category`` of a serialized arguments mapping, or the raw value.

        The arguments are model-generated text, so they are parsed with
        ``json.loads`` and then ``ast.literal_eval`` (for Python-literal
        dictionaries) instead of being executed with ``eval``.
        """
        for parser in (json.loads, ast.literal_eval):
            try:
                parsed = parser(arguments)
            except (ValueError, SyntaxError, TypeError, RecursionError):
                continue
            if isinstance(parsed, dict):
                return parsed.get("category")
            break
        # Case where the prediction is not a proper dictionary.
        return arguments

    def _normalize_prediction(self, final_prediction) -> str:
        """Turn a raw prediction (possibly None or non-string) into a label string."""
        if final_prediction is None:
            # No prediction: return the NA marker so the caller can still
            # convert the value (the original note states it is -1 — confirm
            # against ``NA_VALUE``) or keep it as a string.
            self._none_prediction_counter += 1
            return f"{NA_VALUE}"

        # A category may arrive as a non-string (e.g. a JSON integer).
        final_prediction = str(final_prediction)

        if "La réponse est" in final_prediction or ":" in final_prediction:
            # To handle case where the LLM return the premise to the last query.
            return final_prediction.split(":")[-1].strip().replace(" ", "")

        # Cases where the response is accompanied by other string elements:
        # keep only the numeric label it contains.
        for label in self._NUMERIC_LABELS:
            if label in final_prediction:
                return label
        return final_prediction

    def print_none(self) -> None:
        """Log, as warnings, the non-zero failure counters."""
        if self._none_prediction_counter > 0:
            logging_message = f"Number of None: {self._none_prediction_counter}."
            logging.warning(logging_message)
        if self._max_retries_counter > 0:
            logging_message = f"Number of max retries exceeded occurrence: {self._max_retries_counter}."
            logging.warning(logging_message)
|