# Author: Walter de Back
# Change to instruct model (commit 65f14d7)
import os
from openai import OpenAI
import random
from loguru import logger
from dotenv import load_dotenv
load_dotenv()
from retry import retry
from typing import List, Dict, Optional
from tqdm import tqdm
# Module-level OpenAI client; reads the key loaded by load_dotenv() above.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# speech ISO codes from: https://dl.fbaipublicfiles.com/mms/tts/all-tts-languages.html
# Supported quiz languages keyed by two-letter code. "iso" is the ISO-639-3
# code used by the MMS TTS language list linked above (None = not listed).
# NOTE(review): the "flag" values look like mojibake (UTF-8 emoji decoded as
# another codepage) — verify the file's encoding before editing these strings.
countries = {"nl": {"language":"Dutch", "country":"the Netherlands", "flag": "๐Ÿ‡ณ๐Ÿ‡ฑ", "iso": "nld"},
"de": {"language":"German", "country":"Germany", "flag": "๐Ÿ‡ฉ๐Ÿ‡ช", "iso": "deu"},
"se": {"language":"Swedish", "country":"Sweden", "flag": "๐Ÿ‡ธ๐Ÿ‡ช", "iso": "swe"},
"en": {"language":"English", "country":"England", "flag": "๐Ÿ‡ฌ๐Ÿ‡ง", "iso": "eng"},
"fr": {"language":"French", "country":"France", "flag": "๐Ÿ‡ซ๐Ÿ‡ท", "iso": "fra"},
"it": {"language":"Italian", "country":"Italy", "flag": "๐Ÿ‡ฎ๐Ÿ‡น", "iso": None},
"es": {"language":"Spanish", "country":"Spain", "flag": "๐Ÿ‡ช๐Ÿ‡ธ", "iso": "spa"},
"pl": {"language":"Polish", "country":"Poland", "flag": "๐Ÿ‡ต๐Ÿ‡ฑ", "iso": "pol"},
"hu": {"language":"Hungarian", "country":"Hungary", "flag": "๐Ÿ‡ญ๐Ÿ‡บ", "iso": "hun"},
"fi": {"language":"Finnish", "country":"Finland", "flag": "๐Ÿ‡ซ๐Ÿ‡ฎ", "iso": "fin"},
}
# Tone the prompt asks the model to use; one is picked at random per request.
moods = ["hilarious", "funny", "serious", "poetic"]
# Difficulty name -> [min_words, max_words] for generated sentences.
levels = {"easy": [4, 8],
"medium": [7, 16],
"complicated": [12, 25],
"very complicated": [20, 40]}
def get_languages():
    """Return the supported two-letter language codes, in definition order."""
    return [code for code in countries]
def get_level_names():
    """Return the difficulty level names, easiest first."""
    return [name for name in levels]
class LLM():
    """Generates translation-quiz items via the OpenAI chat API.

    Each item is a dict with keys ``original`` (a sentence in the source
    language), ``true`` (correct translation), ``false`` (plausible but wrong)
    and ``funny`` (deliberately absurd).
    """

    def __init__(self):
        # Idempotent; guarantees OPENAI_API_KEY is loaded even if this class is
        # imported before the module-level load_dotenv() call ran.
        load_dotenv()
        # Minimum character count each parsed field must exceed to be kept.
        self.min_length = 10

    @staticmethod
    def _value_after(lines: List[str], prefix: str) -> str:
        """Return the remainder of the first line starting with *prefix*, or ''."""
        for line in lines:
            if line.startswith(prefix):
                return line[len(prefix):]
        return ""

    def parse_response(self, response: str) -> List[Dict[str, str]]:
        """Parse the model's plain-text reply into QA dicts.

        Expects blocks introduced by ``Sentence N: ...`` followed by lines
        prefixed ``True: ``, ``False: `` and ``Funny: ``.  Entries with any
        field of ``self.min_length`` characters or fewer are discarded.

        :param response: raw text returned by :meth:`generate`.
        :return: list of dicts with keys original/true/false/funny.
        """
        logger.info(f"{response = }")
        logger.info("Parsing response from OpenAI API.")
        tokens = response.split("Sentence ")
        logger.debug(f"Number of tokens: {len(tokens)}")
        logger.debug(f"{tokens = }")
        qas: List[Dict[str, str]] = []
        for token in tokens:
            if not token:
                continue
            lines = token.split("\n")
            # First line looks like "N: <sentence>".  partition() keeps the
            # sentence intact for any item number; the previous fixed slice
            # (lines[0][3:]) silently dropped a character once N reached 10.
            _, sep, original = lines[0].partition(": ")
            if not sep:
                original = lines[0]
            qa = {
                "original": original,
                "true": self._value_after(lines, "True: "),
                "false": self._value_after(lines, "False: "),
                "funny": self._value_after(lines, "Funny: "),
            }
            logger.debug("------------------------")
            logger.debug(f"{qa = }")
            logger.debug("------------------------")
            if all(len(value) > self.min_length for value in qa.values()):
                qas.append(qa)
        logger.info(f"Returning {len(qas)} valid QA pairs.")
        return qas

    @retry(delay=1, backoff=2, max_delay=20, tries=10, logger=logger)
    def generate(self, n: int,
                 input_country: str,
                 target_country: str,
                 level: str,
                 temperature: float = 0.80,
                 streaming: bool = False) -> str:
        """Request *n* quiz items from the OpenAI API and return the raw text.

        :param n: number of sentence/translation sets to request.
        :param input_country: key into ``countries`` for the source language.
        :param target_country: key into ``countries`` for the target language.
        :param level: key into ``levels`` controlling sentence length.
        :param temperature: sampling temperature, must be in [0, 1].
        :param streaming: if True, stream chunks and join them client-side.
        :raises ValueError: if *temperature* is out of range.
        :return: the model's reply as plain text (parse with parse_response).
        """
        input_length = levels[level]
        language = countries[input_country]['language']
        country = countries[input_country]['country']
        target_language = countries[target_country]['language']
        mood = random.choice(moods)
        logger.info(f"Generating {n} QA pairs for {language} to {target_language} with level: {level}.")
        # Raise instead of assert: assertions are stripped under `python -O`.
        if not 0.0 <= temperature <= 1.0:
            raise ValueError("temperature must be between 0 and 1")
        response = client.chat.completions.create(
            # https://community.openai.com/t/how-to-pass-prompt-to-the-chat-completions-create/592629/2
            # NOTE(review): "-instruct" models are completions-style models;
            # confirm the chat endpoint accepts this model name.
            model="gpt-3.5-turbo-instruct",
            #model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content":
                    f"You are a helpful assistant at creating a translation game. "
                    f"You create sentences in {language} with a length that is exactly between {input_length[0]} to {input_length[1]} words. "
                    f"You create translations in {target_language}. Do not include a translation in the original sentence. "
                    f"Ensure that each translation is quite different from the other ones and reduce repetition. "
                    # Fixed: the old template had "\True" (a literal backslash-T,
                    # not a newline), so the model was shown a malformed format
                    # that parse_response() could not match.
                    f"Formulate your answer in exactly this format: Sentence N: [X],\nTrue: [A],\nFalse: [B],\nFunny: [C]."},
                {"role": "assistant", "content":
                    f"Sometimes, create original sentences with words, locations, concepts and phrases are typical for {country}. Otherwise, create sentences that are unrelated to the country."},
                {"role": "user", "content":
                    f"Generate {n} {level} and {mood} sentences in {language} and 3 corresponding translations in {target_language}. "
                    f"You create 1 correct translation, 1 incorrect translations, and 1 which is very wrong and funny. "},
            ],
            temperature=temperature,
            stream=streaming)
        logger.info(f"{response = }")
        if streaming:
            import time
            start_time = time.time()
            collected_messages: List[str] = []
            for chunk in tqdm(response, total=250):
                # openai>=1.0 streams typed ChatCompletionChunk objects, not
                # dicts: the old chunk['choices'][0]['delta'] subscripting
                # raised TypeError.  Read the delta via attributes instead;
                # content is None on role/stop chunks, so guard before joining.
                delta = chunk.choices[0].delta
                if delta.content:
                    collected_messages.append(delta.content)
            logger.debug(f"Streaming finished {time.time() - start_time:.2f} seconds after request.")
            return ''.join(collected_messages)
        else:
            return response.choices[0].message.content

    def get_QAs(self, n: int, input_countries: List[str], target_country: str, level: str, debug: bool = False):
        """Generate and parse *n* QA items for a randomly chosen source language.

        :param debug: if True, skip the API entirely and return canned data.
        :return: list of QA dicts (see parse_response).
        """
        logger.info("Generating new sentences...")
        # gr.Info(f"Generating new Q&A in {input_countries} with level: {level}.")
        if debug:
            return [ {"original": "The Netherlands is a country in Europe.",
                      "true": "Nederland is een land in Europa.",
                      "false": "Nederland is een land in Azië.",
                      "funny": "Nederland is een aap in Europa."},
                     {"original": "Aap, noot, mies.",
                      "true": "Aap.",
                      "false": "Noot.",
                      "funny": "Mies."} ]
        else:
            return self.parse_response(
                self.generate(n=n,
                              input_country=random.choice(input_countries),
                              target_country=target_country,
                              level=level))