Spaces:
Sleeping
Sleeping
| import os | |
| from openai import OpenAI | |
| import random | |
| from loguru import logger | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| from retry import retry | |
| from typing import List, Dict, Optional | |
| from tqdm import tqdm | |
| client = OpenAI(api_key=os.getenv("OPENAI_API_KEY")) | |
# Supported languages, keyed by a short code.
# Each entry holds:
#   language - English name of the language (interpolated into the prompts)
#   country  - country name (interpolated into the prompts)
#   flag     - flag emoji for the country
#   iso      - speech ISO code, or None when no code is listed for the language
# speech ISO codes from: https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html
countries = {
    "nl": {"language": "Dutch", "country": "the Netherlands", "flag": "🇳🇱", "iso": "nld"},
    "de": {"language": "German", "country": "Germany", "flag": "🇩🇪", "iso": "deu"},
    "se": {"language": "Swedish", "country": "Sweden", "flag": "🇸🇪", "iso": "swe"},
    "en": {"language": "English", "country": "England", "flag": "🇬🇧", "iso": "eng"},
    "fr": {"language": "French", "country": "France", "flag": "🇫🇷", "iso": "fra"},
    "it": {"language": "Italian", "country": "Italy", "flag": "🇮🇹", "iso": None},
    "es": {"language": "Spanish", "country": "Spain", "flag": "🇪🇸", "iso": "spa"},
    "pl": {"language": "Polish", "country": "Poland", "flag": "🇵🇱", "iso": "pol"},
    "hu": {"language": "Hungarian", "country": "Hungary", "flag": "🇭🇺", "iso": "hun"},
    "fi": {"language": "Finnish", "country": "Finland", "flag": "🇫🇮", "iso": "fin"},
}
# Tones a request can ask for; one is picked at random per generation call.
moods = [
    "hilarious",
    "funny",
    "serious",
    "poetic",
]

# Difficulty name -> [lower, upper] bound on the sentence length in words,
# as interpolated into the generation prompt.
levels = {
    "easy": [4, 8],
    "medium": [7, 16],
    "complicated": [12, 25],
    "very complicated": [20, 40],
}
def get_languages():
    """Return the codes of all supported languages (the keys of ``countries``) as a new list."""
    return list(countries)
def get_level_names():
    """Return the names of all difficulty levels (the keys of ``levels``) as a new list."""
    return list(levels)
class LLM():
    """Create translation-game questions with the OpenAI chat API and parse the replies.

    Attributes:
        min_length: a parsed question is kept only when each of its four
            fields is strictly longer than this many characters.
    """

    def __init__(self):
        load_dotenv()
        self.min_length = 10

    def parse_response(self, response: str) -> List[Dict[str, str]]:
        """Split a raw model reply into question/answer dictionaries.

        The reply is expected to consist of blocks shaped like::

            Sentence N: <original sentence>
            True: <correct translation>
            False: <incorrect translation>
            Funny: <funny translation>

        Args:
            response: the text returned by ``generate``. ``None`` or an empty
                string yields an empty list.

        Returns:
            One dict per complete block, with the keys ``original``, ``true``,
            ``false`` and ``funny``. Blocks in which any field is not longer
            than ``self.min_length`` characters are dropped.
        """
        logger.info(f"{response = }")
        logger.info(f"Parsing response from OpenAI API.")
        text = response or ""
        tokens = text.split("Sentence ")
        logger.debug(f"Number of tokens: {len(tokens)}")
        logger.debug(f"{tokens = }")
        prefixes = {"true": "True: ", "false": "False: ", "funny": "Funny: "}
        qas = []
        for token in tokens:
            if len(token) == 0:
                continue
            # Strip each line so "\r\n" line endings and indentation do not
            # hide the "True: " / "False: " / "Funny: " markers.
            lines = [line.strip() for line in token.split("\n")]
            header = lines[0]
            number, colon, sentence = header.partition(":")
            if colon and number.strip().isdigit():
                # Well-formed "N: sentence" header; works for any number of digits.
                original = sentence.strip()
            else:
                # Unexpected header: keep the historical fixed-width cut.
                original = header[3:]
            qa = {"original": original, "true": "", "false": "", "funny": ""}
            for line in lines:
                for key, prefix in prefixes.items():
                    if line.startswith(prefix):
                        # The last matching line wins when a marker repeats.
                        qa[key] = line[len(prefix):]
            logger.debug(f"------------------------")
            logger.debug(f"{qa = }")
            logger.debug(f"------------------------")
            if all(len(value) > self.min_length for value in qa.values()):
                qas.append(qa)
        logger.info(f"Returning {len(qas)} valid QA pairs.")
        return qas

    def generate(self, n: int,
                 input_country: str,
                 target_country: str,
                 level: str,
                 temperature: float = 0.80,
                 streaming: bool = False,
                 model: str = "gpt-3.5-turbo") -> str:
        """Ask the chat model for ``n`` sentences with three translations each.

        Args:
            n: number of sentences to request.
            input_country: key into ``countries`` for the language of the sentences.
            target_country: key into ``countries`` for the language of the translations.
            level: key into ``levels`` selecting the sentence length range.
            temperature: sampling temperature; must lie in [0, 1].
            streaming: when true, the reply is requested as a stream and
                assembled chunk by chunk.
            model: chat-capable model name. The chat completions endpoint does
                not serve completion-only models such as
                ``gpt-3.5-turbo-instruct``.

        Returns:
            The raw reply text, ready for ``parse_response`` (empty string if
            the model returned no content).

        Raises:
            ValueError: if ``temperature`` is outside [0, 1].
            KeyError: if a country code or the level is unknown.
        """
        # Validate with an exception rather than ``assert`` so the check
        # still runs under ``python -O``.
        if not 0.0 <= temperature <= 1.0:
            raise ValueError("temperature must be between 0 and 1")

        min_words, max_words = levels[level]
        language = countries[input_country]['language']
        country = countries[input_country]['country']
        target_language = countries[target_country]['language']
        mood = random.choice(moods)
        logger.info(f"Generating {str(n)} QA pairs for {language} to {target_language} with level: {level}.")

        system_prompt = (
            "You are a helpful assistant at creating a translation game. "
            f"You create sentences in {language} with a length that is exactly between {min_words} to {max_words} words. "
            f"You create translations in {target_language}. Do not include a translation in the original sentence. "
            "Ensure that each translation is quite different from the other ones and reduce repetition. "
            "Formulate your answer in exactly this format: Sentence N: [X],\nTrue: [A],\nFalse: [B],\nFunny: [C]."
        )
        assistant_prompt = (
            f"Sometimes, create original sentences with words, locations, concepts and phrases are typical for {country}. "
            "Otherwise, create sentences that are unrelated to the country."
        )
        user_prompt = (
            f"Generate {n} {level} and {mood} sentences in {language} and 3 corresponding translations in {target_language}. "
            "You create 1 correct translation, 1 incorrect translations, and 1 which is very wrong and funny."
        )

        response = client.chat.completions.create(
            model=model,  # https://community.openai.com/t/how-to-pass-prompt-to-the-chat-completions-create/592629/2
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "assistant", "content": assistant_prompt},
                {"role": "user", "content": user_prompt},
            ],
            temperature=temperature,
            stream=streaming)
        logger.info(f"{response = }")

        if streaming:
            pieces = []
            # total=250 is presumably just a rough estimate to scale the progress bar.
            for chunk in tqdm(response, total=250):
                if not chunk.choices:
                    continue
                # Chunks are objects, not dicts; delta.content may be None
                # (for example on the final chunk).
                piece = chunk.choices[0].delta.content
                if piece:
                    pieces.append(piece)
            return "".join(pieces)
        return response.choices[0].message.content or ""

    def get_QAs(self, n: int, input_countries: List[str], target_country: str, level: str, debug: bool = False) -> List[Dict[str, str]]:
        """Return parsed question/answer dicts for a randomly chosen input language.

        Args:
            n: number of sentences to request from the model.
            input_countries: candidate keys into ``countries``; one is picked
                at random for this request.
            target_country: key into ``countries`` for the translation language.
            level: key into ``levels``.
            debug: when true, return two fixed sample entries without calling
                the API.

        Returns:
            A list of dicts with the keys ``original``, ``true``, ``false``
            and ``funny``.
        """
        logger.info(f"Generating new sentences...")
        # gr.Info(f"Generating new Q&A in {input_countries} with level: {level}.")
        if debug:
            return [
                {"original": "The Netherlands is a country in Europe.",
                 "true": "Nederland is een land in Europa.",
                 "false": "Nederland is een land in Azië.",
                 "funny": "Nederland is een aap in Europa."},
                {"original": "Aap, noot, mies.",
                 "true": "Aap.",
                 "false": "Noot.",
                 "funny": "Mies."},
            ]
        return self.parse_response(
            self.generate(n=n,
                          input_country=random.choice(input_countries),
                          target_country=target_country,
                          level=level))