Walter de Back
Set retry delay to 1 sec
a813bef
import os
import re
from typing import List, Dict

from dotenv import load_dotenv
from loguru import logger
from openai import OpenAI
from retry import retry

load_dotenv()
# Module-level OpenAI client, created at import time. The API key is read from the
# OPENAI_API_KEY environment variable; if it is unset, os.getenv returns None.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# Supported languages, keyed by a short code. The "language" and "country"
# names are interpolated into the prompts that generate() sends to the model.
country = dict(
    nl=dict(language="Dutch", country="Netherlands"),
    de=dict(language="German", country="Germany"),
    se=dict(language="Swedish", country="Sweden"),
    en=dict(language="English", country="England"),
)

# Difficulty levels: [min_words, max_words] for each generated sentence.
levels = dict(
    easy=[3, 6],
    medium=[5, 12],
    hard=[12, 20],
)
def get_languages():
    """Return the supported language codes (the keys of ``country``) as a new list, in definition order."""
    return [code for code in country]
def get_level():
    """Return the difficulty names (the keys of ``levels``) as a new list, in definition order."""
    return [*levels]
# Matches the leading "N:" label (any number of digits) that follows each "Sentence " marker.
_SENTENCE_LABEL = re.compile(r"^\s*\d+\s*[:.)]\s*")
# Answer-line prefixes mapped to the QA-dict key each one fills.
_ANSWER_PREFIXES = {"True:": "true", "False:": "false", "Funny:": "funny"}


def _clean(value: str) -> str:
    """Trim whitespace and the trailing ',' separator the prompt's answer format asks for."""
    return value.strip().rstrip(",").rstrip()


def parse_response(response, min_length: int = 10) -> List[Dict[str, str]]:
    """Parse a chat-completion response into a list of QA dicts.

    Args:
        response: Chat-completion object as returned by ``generate``; only
            ``response.choices[0].message.content`` is read.
        min_length: Every field of a QA item must be strictly longer than this
            many characters, otherwise the item is dropped.

    Returns:
        A list of dicts with keys ``original``, ``true``, ``false`` and ``funny``.
        The text is split on "Sentence " markers. Each chunk's first line is the
        original sentence (its numeric label removed). Lines starting with
        "True:", "False:" and "Funny:" give the answers; if a prefix repeats, the
        last occurrence wins.
    """
    logger.debug(f"{response = }")
    logger.info("Parsing response from OpenAI API.")
    # Guard against a missing (None) message content so .split() cannot fail.
    text = response.choices[0].message.content or ""
    tokens = text.split("Sentence ")
    logger.debug(f"Number of tokens: {len(tokens)}")
    logger.debug(f"{tokens = }")

    qas = []
    for token in tokens:
        if not token:
            continue
        # splitlines() also copes with "\r\n" line endings.
        lines = token.splitlines()
        # Strip the "N:" label whatever its digit count; the old fixed [3:] slice
        # mangled labels with two or more digits ("10: ..." -> "0: ...").
        original = _clean(_SENTENCE_LABEL.sub("", lines[0], count=1))
        qa = {"original": original, "true": "", "false": "", "funny": ""}
        for line in lines:
            stripped = line.strip()
            for prefix, key in _ANSWER_PREFIXES.items():
                if stripped.startswith(prefix):
                    qa[key] = _clean(stripped[len(prefix):])
        logger.debug("------------------------")
        logger.debug(f"{qa = }")
        logger.debug("------------------------")
        # Drop incomplete items (e.g. the preamble before the first "Sentence ").
        if all(len(value) > min_length for value in qa.values()):
            qas.append(qa)
    logger.info(f"Returning {len(qas)} valid QA pairs.")
    return qas
# NOTE(review): the retry package logs via logger.warning("%s, retrying in %s seconds...", ...);
# loguru formats with str.format, so the %s placeholders may show up literally. Confirm.
@retry(delay=1, backoff=2, max_delay=20, tries=10, logger=logger)
def _request_completion(messages: List[Dict[str, str]], temperature: float, model: str):
    """Send one chat-completion request; any exception triggers a retry with backoff."""
    return client.chat.completions.create(model=model, messages=messages, temperature=temperature)


def generate(n: int, input_country: str, target_country: str, level: str,
             temperature: float = 0.80, model: str = "gpt-3.5-turbo"):
    """Ask the chat model for ``n`` sentences with true/false/funny translations.

    Args:
        n: Number of sentences to request.
        input_country: Key of ``country`` for the language of the original sentences.
        target_country: Key of ``country`` for the translation language.
        level: Key of ``levels`` giving the [min, max] word count per sentence.
        temperature: Sampling temperature, must be within [0, 1].
        model: Chat model name (defaults to the previously hard-coded model).

    Returns:
        The raw chat-completion response; feed it to ``parse_response``.

    Raises:
        KeyError: Unknown country code or level.
        ValueError: ``temperature`` outside [0, 1].
    Validation happens before the retried API call, so bad arguments fail
    immediately instead of being retried ten times with backoff.
    """
    min_words, max_words = levels[level]
    source = country[input_country]
    target = country[target_country]
    if not 0.0 <= temperature <= 1.0:
        raise ValueError("temperature must be between 0 and 1")
    logger.info(f"Generating {n} QA pairs for {source['language']} to {target['language']} with level: {level}.")

    # Explicit spaces and "\n" separators: the old backslash-continued literal
    # glued sentences together and contained "\True" instead of "\nTrue".
    system_prompt = (
        "You are a helpful assistant at creating a translation game. "
        f"You create sentences in {source['language']} with a length that is exactly "
        f"between {min_words} to {max_words} words. "
        f"You create translations in {target['language']}. "
        "Formulate your answer in exactly this format: "
        "Sentence N: [X],\nTrue: [A],\nFalse: [B],\nFunny: [C]."
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "assistant", "content":
            f"Create original sentences with words, locations, concepts and phrases are typical for {source['country']} ."},
        {"role": "user", "content":
            f"Generate {n} funny sentences. You create 1 correct translation, 1 incorrect translations, and 1 which very wrong and funny."},
    ]
    return _request_completion(messages, temperature, model)
# Example manual run (note: the difficulty argument is `level`, not `input_length`):
# get_QAs(n=10, input_country="nl", target_country="de", level="easy", debug=False)
def get_QAs(n:int, input_country:str, target_country:str, level:str, debug:bool):
    """Return a list of QA dicts for the translation game.

    When ``debug`` is true, two hard-coded fixture items are returned and no
    API call is made. Otherwise ``n`` sentences are requested through
    ``generate`` and parsed with ``parse_response``, which may drop incomplete
    items and so return fewer than ``n``.
    """
    if not debug:
        completion = generate(n=n, input_country=input_country,
                              target_country=target_country, level=level)
        return parse_response(completion)

    # Canned fixtures so callers can run without contacting the API.
    netherlands_item = dict(
        original="The Netherlands is a country in Europe.",
        true="Nederland is een land in Europa.",
        false="Nederland is een land in Azië.",
        funny="Nederland is een aap in Europa.",
    )
    rhyme_item = dict(
        original="Aap, noot, mies.",
        true="Aap.",
        false="Noot.",
        funny="Mies.",
    )
    return [netherlands_item, rhyme_item]