Spaces:
Build error
Build error
from model import encoder_text
import torch
import clip
import random
import numpy as np
from words import words

# Device that tokenised text is moved to before CLIP text encoding.
device = torch.device("cpu")
########## SENTENCE PART #######################################################
# Vowels: before a word starting with one of these, link_text turns "a" into "an".
voyelles = list("aeiou")
# Every key of `words` except the first one (presumably the opening
# "a drawing of a" link); iniSentence draws follow-up links from this list.
links = list(words)[1:]
def link_text(part, nextWord):
    """Return part["link"] with its trailing article adjusted for nextWord.

    When the link ends with the standalone indefinite article "a":
      * the article is dropped for plural/uncountable words
        ("with a" + "hair" -> "with", "a drawing of a" + "cats" -> "a drawing of"),
      * it becomes "an" before a word starting with a vowel
        ("with a" + "apple" -> "with an").
    Any other link is returned unchanged.

    Args:
        part: sentence-part dict; only part["link"] is read.
        nextWord: the word that will follow the link.
    """
    link = part["link"]
    # Only a standalone trailing "a" is an article; a link whose last word merely
    # ends in "a" must not be truncated or given an extra "n".
    if not (link == "a" or link.endswith(" a")):
        return link
    # endswith() instead of indexing so one-letter words do not raise IndexError.
    plural = (nextWord.endswith("s") and not nextWord.endswith("ss")) or (
        nextWord in ["nothing", "hair", "vampire teeth", "something"]
    )
    if plural:
        return link[:-2]  # drop " a" (or the bare "a")
    # Slicing keeps an empty word from raising IndexError.
    return link + ("n" if nextWord[:1] in voyelles else "")
def part_text(part):
    """Render one sentence part as text: its article-adjusted link, then its word.

    The two are separated by a single space, which is omitted when the link is empty.
    """
    word = part["word"]
    link = link_text(part, word)
    separator = " " if link else ""
    return link + separator + word
def compute_embeddings(part, var_dict, prefix, batch_size=64):
    """Build the candidate classes for one sentence part and embed them.

    Distractors come from words[part["link"]], excluding the target word and
    var_dict["found_words"]. At most batch_size - 1 of them are kept, sampled
    without replacement. The target is appended LAST, so part["classes"][-1] is
    always the correct answer (updateState relies on this).

    Side effects:
        part["classes"]: full sentences, each starting with `prefix`.
        part["embeddings"]: output of encoder_text divided by its norm along the
            last dim, as nested lists (presumably one row per class — confirm
            against model.encoder_text).
    """
    target = part["word"]
    excluded = {target, *var_dict["found_words"]}
    # Sort so the sample depends only on NumPy's RNG state, not on the
    # hash-randomised iteration order of a set of strings.
    possibleWords = sorted(set(words[part["link"]]) - excluded)
    if len(possibleWords) > batch_size - 1:
        possibleWords = np.random.choice(
            possibleWords, batch_size - 1, replace=False
        ).tolist()
    possibleWords.append(target)
    # Render each candidate through part_text (on a copy carrying that word) so the
    # wording matches how the guessed sentence is rendered, computing the link once.
    part["classes"] = [prefix + part_text(dict(part, word=w)) for w in possibleWords]
    with torch.no_grad():
        embeddings = encoder_text(clip.tokenize(part["classes"]).to(device))
        embeddings /= embeddings.norm(dim=-1, keepdim=True)
    part["embeddings"] = embeddings.tolist()
########## SENTENCE ############################################################
def _add_sentence_part(var_dict, link, word, prefix):
    """Append a new part to var_dict["parts"], embed its candidates, and return the extended prefix."""
    part = {"link": link, "word": word, "classes": [], "embeddings": []}
    var_dict["parts"].append(part)
    compute_embeddings(part, var_dict, prefix)
    return prefix + part_text(part) + " "


def iniSentence(var_dict, input="", first_game=False):
    """Start a new sentence to guess and return it.

    Resets var_dict["found_words"], var_dict["parts"] and var_dict["step"]. Then:
      * first_game: the fixed sentence "a drawing of a cat with a face";
      * otherwise: a random word for the "a drawing of a" link, plus one extra
        random (link, word) part when var_dict["difficulty"] == 1. For follow-up
        links the first entry of words[link] is never chosen as the target.

    Stores the sentence in var_dict["target_sentence"], refreshes the derived
    state via setState, and returns the sentence.

    Args:
        var_dict: mutable game state; must contain "difficulty".
        input: unused.
        first_game: use the fixed tutorial sentence instead of a random one.

    NOTE(review): "prev_steps", "revertedState" and "win" are not reset here,
    presumably the caller does so for each new game — verify, otherwise a stale
    undo history could point at steps of the previous sentence.
    """
    var_dict["found_words"] = []
    var_dict["parts"] = []
    var_dict["step"] = 0
    prefix = ""
    # Read difficulty up front (as before) so a missing key fails in both modes.
    n_parts = 2 if var_dict["difficulty"] == 1 else 1
    if first_game:
        prefix = _add_sentence_part(var_dict, "a drawing of a", "cat", prefix)
        prefix = _add_sentence_part(var_dict, "with a", "face", prefix)
    else:
        ##### Generating Random Sentence
        link = "a drawing of a"
        prefix = _add_sentence_part(var_dict, link, np.random.choice(words[link]), prefix)
        for _ in range(n_parts - 1):
            link = np.random.choice(links)
            prefix = _add_sentence_part(
                var_dict, link, np.random.choice(words[link][1:]), prefix
            )
    var_dict["target_sentence"] = prefix[:-1]  # drop the trailing space
    setState(var_dict)
    return var_dict["target_sentence"]
def prevState(var_dict):
    """Undo to the most recently recorded step, or to step 0 if none is recorded.

    Also sets var_dict["revertedState"] so that the next updateState call does
    not record the current step again, then refreshes the derived state.
    NOTE(review): var_dict["win"] is left untouched.
    """
    history = var_dict["prev_steps"]
    var_dict["step"] = history.pop() if history else 0
    var_dict["revertedState"] = True
    setState(var_dict)
def setState(var_dict):
    """Sync derived state with var_dict["step"].

    Truncates var_dict["found_words"] to the first `step` entries and rebuilds
    var_dict["guessed_sentence"] from the first `step` parts. Each part is rendered
    with part_text and followed by a space. Raises IndexError if `step` exceeds
    the number of parts.
    """
    step = var_dict["step"]
    parts = var_dict["parts"]
    var_dict["found_words"] = var_dict["found_words"][:step]
    # Build with join (not repeated +=); index-based so an out-of-range step still raises.
    var_dict["guessed_sentence"] = "".join(part_text(parts[i]) + " " for i in range(step))
def updateState(var_dict, preds):
    """Advance the game when the current sentence part is recognised.

    Only the first three entries of `preds` are inspected. The current step is
    recorded in var_dict["prev_steps"] for undo, unless this call directly
    follows a revert (var_dict["revertedState"]), in which case the flag is
    cleared instead.

    Returns:
        1 if the sentence is fully guessed (now or already),
        0 if the current part was just guessed but parts remain,
        -1 otherwise.
    """
    if not var_dict["revertedState"]:
        var_dict["prev_steps"].append(var_dict["step"])
    else:
        var_dict["revertedState"] = False
    # Once won, step can equal len(parts); indexing parts then would raise.
    if var_dict["win"]:
        return 1
    part = var_dict["parts"][var_dict["step"]]
    top = preds[:3]  # tolerate fewer than three predictions
    idx_of_nothing = next((i for i, p in enumerate(top) if "nothing" in p), -1)
    # compute_embeddings always puts the target class last.
    target = part["classes"][-1]
    idx_of_guess = next((i for i, p in enumerate(top) if p == target), -1)
    # NOTE(review): this advances when the target appears LATER in preds than a
    # "nothing" class (or when no "nothing" class is among them). That means
    # "target beats nothing" only if preds is ordered worst-to-best; if the caller
    # passes best-first predictions the comparison is inverted — confirm.
    if idx_of_guess > idx_of_nothing:
        var_dict["step"] += 1
        var_dict["found_words"].append(part["word"])
        var_dict["win"] = var_dict["step"] == len(var_dict["parts"])
        setState(var_dict)
        return 1 if var_dict["win"] else 0
    return -1