Spaces:
Sleeping
Sleeping
| from openai import OpenAI | |
| import os | |
| from loguru import logger | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| from retry import retry | |
| from typing import List, Dict | |
# Module-wide OpenAI client. The key is taken from the OPENAI_API_KEY
# environment variable; load_dotenv() has already been called above.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# Supported language codes. Each code maps to the language name and the
# country name that generate() interpolates into its prompts.
country = {
    code: {"language": language, "country": nation}
    for code, language, nation in (
        ("nl", "Dutch", "Netherlands"),
        ("de", "German", "Germany"),
        ("se", "Swedish", "Sweden"),
        ("en", "English", "England"),
    )
}

# Difficulty levels. Each value is the [lower, upper] word-count pair that
# generate() quotes in its system prompt.
levels = dict(
    easy=[3, 6],
    medium=[5, 12],
    hard=[12, 20],
)
def get_languages():
    """Return the language codes configured in ``country`` as a new list."""
    return list(country)
def get_level():
    """Return the difficulty names configured in ``levels`` as a new list."""
    return list(levels)
def parse_response(response, min_length: int = 10) -> List[Dict[str, str]]:
    """Turn a chat-completion reply into a list of translation questions.

    The reply text is split on the marker ``"Sentence "``. Each resulting
    chunk is expected to look like::

        1: <original sentence>
        True: <correct translation>
        False: <incorrect translation>
        Funny: <absurd translation>

    Args:
        response: Object exposing ``choices[0].message.content``, as returned
            by ``client.chat.completions.create`` in ``generate``.
        min_length: A chunk is kept only if every one of its four fields is
            strictly longer than this many characters.

    Returns:
        One dict per accepted chunk with the keys ``"original"``, ``"true"``,
        ``"false"`` and ``"funny"``.
    """
    logger.debug(f"{response = }")
    logger.info("Parsing response from OpenAI API.")

    # The message content can be None; treat that as an empty reply
    # instead of crashing on ``None.split``.
    text = response.choices[0].message.content or ""
    tokens = text.split("Sentence ")
    logger.debug(f"Number of tokens: {len(tokens)}")
    logger.debug(f"{tokens = }")

    # Line prefix -> key of the field it fills.
    prefixes = {"True: ": "true", "False: ": "false", "Funny: ": "funny"}

    qas = []
    for token in tokens:
        if not token:
            continue
        lines = token.split("\n")

        # The header is "<number>: <sentence>". Cut at the first colon so
        # multi-digit numbers ("10: ...") are removed completely. A header
        # without a colon falls back to the previous fixed three-character
        # cut.
        header = lines[0]
        _, colon, remainder = header.partition(":")
        original = remainder if colon else header[3:]

        qa = {"original": original.strip(), "true": "", "false": "", "funny": ""}
        for raw_line in lines:
            line = raw_line.strip()  # tolerate indentation and "\r\n" endings
            for prefix, key in prefixes.items():
                if line.startswith(prefix):
                    # A later matching line overrides an earlier one.
                    qa[key] = line[len(prefix):].strip()

        logger.debug("------------------------")
        logger.debug(f"{qa = }")
        logger.debug("------------------------")

        # Drop chunks with a missing or suspiciously short field (this also
        # discards any preamble text in front of the first sentence).
        if all(len(value) > min_length for value in qa.values()):
            qas.append(qa)

    logger.info(f"Returning {len(qas)} valid QA pairs.")
    return qas
def generate(n: int, input_country: str, target_country: str, level: str, temperature: float = 0.80):
    """Ask the chat model for ``n`` sentences with three candidate translations each.

    Args:
        n: Number of sentences requested in the user prompt.
        input_country: Key into ``country``; selects the language of the
            original sentences and the country they should be typical for.
        target_country: Key into ``country``; selects the translation language.
        level: Key into ``levels``; selects the word-count range quoted in
            the system prompt.
        temperature: Sampling temperature passed to the API; must lie in
            the closed interval [0, 1].

    Returns:
        The object returned by ``client.chat.completions.create``, unparsed.
        ``parse_response`` reads ``choices[0].message.content`` from it.

    Raises:
        ValueError: If ``temperature`` is outside [0, 1].
        KeyError: If ``input_country``/``target_country`` is not in
            ``country`` or ``level`` is not in ``levels``.
    """
    # Validate with a real exception: ``assert`` statements are removed when
    # Python runs with -O, which would let bad values reach the API.
    if not 0.0 <= temperature <= 1.0:
        raise ValueError("temperature must be between 0 and 1")

    source = country[input_country]
    target = country[target_country]
    min_words, max_words = levels[level]
    logger.info(
        f"Generating {n} QA pairs for {source['language']} to "
        f"{target['language']} with level: {level}."
    )

    # The answer layout must match what parse_response looks for: a
    # "Sentence N: ..." header followed by lines starting with "True: ",
    # "False: " and "Funny: ". Each field therefore sits on its own line.
    system_prompt = (
        "You are a helpful assistant at creating a translation game. "
        f"You create sentences in {source['language']} with a length that is "
        f"exactly between {min_words} to {max_words} words. "
        f"You create translations in {target['language']}. "
        "Formulate your answer in exactly this format: "
        "Sentence N: [X],\nTrue: [A],\nFalse: [B],\nFunny: [C]."
    )
    assistant_prompt = (
        "Create original sentences with words, locations, concepts and "
        f"phrases are typical for {source['country']} ."
    )
    user_prompt = (
        f"Generate {n} funny sentences. You create 1 correct translation, "
        "1 incorrect translations, and 1 which very wrong and funny."
    )

    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "assistant", "content": assistant_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=temperature,
    )
    return response
# Example arguments for get_QAs / generate:
#   n = 10
#   input_country = "nl"
#   target_country = "de"
#   level = "easy"
def get_QAs(n:int, input_country:str, target_country:str, level:str, debug:bool):
    """Return translation questions, either canned or freshly generated.

    When ``debug`` is truthy a fixed two-item sample is returned and no API
    call is made; the other arguments are ignored in that case. Otherwise
    the arguments are forwarded to ``generate`` and its reply is run through
    ``parse_response``.

    Each returned item is a dict with the keys ``"original"``, ``"true"``,
    ``"false"`` and ``"funny"``.
    """
    if not debug:
        api_reply = generate(
            n=n,
            input_country=input_country,
            target_country=target_country,
            level=level,
        )
        return parse_response(api_reply)

    # Offline sample; a new list of new dicts is built on every call.
    fields = ("original", "true", "false", "funny")
    sample_rows = (
        (
            "The Netherlands is a country in Europe.",
            "Nederland is een land in Europa.",
            "Nederland is een land in Azië.",
            "Nederland is een aap in Europa.",
        ),
        ("Aap, noot, mies.", "Aap.", "Noot.", "Mies."),
    )
    return [dict(zip(fields, row)) for row in sample_rows]