| import numpy as np | |
| import pickle | |
| import embedding | |
| import random | |
| import embed_set | |
| import net | |
| from tqdm import tqdm | |
| from tensorflow.keras.models import load_model | |
# Candidate-pool size for nearest-neighbour sampling; same value as the
# default `top_p` of top_closest_vectors (1 = only the single closest entry).
top_p: int = 1
class SetLine:
    """One entry of the line set: a line's text plus its embedding vector.

    Attributes:
        name: The text passed to the constructor.
        inp: The vector returned by ``embedding.getvec(name)``.
    """

    def __init__(self, name, inp):
        # NOTE(review): the `inp` argument is never used -- the vector is
        # always recomputed from `name`. The parameter is kept so existing
        # callers keep working; confirm whether ignoring it is intentional.
        self.name, self.inp = name, embedding.getvec(name)
# Load the reference line set that generate() searches.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this at a trusted, locally produced "set.pckl".
with open("set.pckl", mode="rb") as f:
    dset = pickle.load(f)

# Keras model whose output generate() matches against the entries of `dset`.
model = load_model("net.h5")
def top_closest_vectors(input_vector, top_p=1, dataset=None):
    """Return the ``top_p`` dataset entries nearest to ``input_vector``.

    Distance is the Euclidean norm of ``entry.inp - input_vector``.

    Args:
        input_vector: Query vector; must broadcast against each entry's ``inp``.
        top_p: Number of nearest entries to return (non-negative). If it
            exceeds the number of entries, every entry is returned.
        dataset: Sequence of objects exposing an ``inp`` attribute to search.
            Defaults to the module-level ``dset``.

    Returns:
        A list of ``(distance, index)`` tuples in ascending order of distance,
        where ``index`` is the entry's position in ``dataset``. Equal
        distances keep dataset order.
    """
    import heapq  # local import keeps this block self-contained

    if dataset is None:
        dataset = dset
    distances = (
        (np.linalg.norm(entry.inp - input_vector), index)
        for index, entry in enumerate(dataset)
    )
    # nsmallest is documented as equivalent to sorted(..., key=key)[:n] (same
    # stable tie order) but only keeps `top_p` candidates instead of sorting
    # the whole dataset.
    return heapq.nsmallest(top_p, distances, key=lambda pair: pair[0])
def generate(text):
    """Predict the line that follows ``text`` and return the closest known line.

    The last three newline-separated lines of ``text`` are embedded (left-padded
    with zero vectors when there are fewer than three) and fed to ``model`` as a
    batch of one. The predicted vector is matched against ``dset``; one of the
    ``top_p`` nearest entries is chosen uniformly at random.

    Args:
        text: Newline-separated input lines.

    Returns:
        The ``name`` attribute of the selected ``dset`` entry.
    """
    window = 3  # number of trailing lines the model consumes
    # Only the last `window` embeddings survive the slice below, so embed just
    # those lines instead of every line in `text`.
    recent = [embedding.getvec(line) for line in text.split("\n")[-window:]]
    padding = [np.zeros(net.vec_size)] * window
    # NOTE(review): assumes embedding.getvec returns a 1-D vector of length
    # net.vec_size, giving a (1, window, net.vec_size) batch -- confirm against
    # the model's input shape.
    vecs = np.array([(padding + recent)[-window:]])
    rvec = model.predict(vecs)[0]
    _, index = random.choice(top_closest_vectors(rvec, top_p))
    return dset[index].name