Spaces:
Runtime error
Runtime error
| import pickle | |
| import numpy as np | |
| import faiss | |
| from tqdm.auto import tqdm | |
def sample_tf(x, y, ndim=1000):
    """Sample a step-wise transfer function on a uniform grid over [0, 1].

    Parameters
    ----------
    x : sequence of float
        Step breakpoints. Grid point ``t`` matches breakpoint ``i`` when
        ``t >= x[i]``. Presumably ascending (TODO confirm); with unsorted
        breakpoints the later entry still wins wherever several match.
    y : sequence
        Value for each breakpoint, following ``np.piecewise`` semantics:
        when several breakpoints match, the later one wins. Points matching
        none are 0, unless ``y`` carries one extra trailing "otherwise"
        value.
    ndim : int
        Number of evenly spaced samples taken on [0, 1].

    Returns
    -------
    numpy.ndarray
        Array of ``ndim`` sampled values.
    """
    grid = np.linspace(0, 1, ndim)
    # One boolean mask per breakpoint; np.piecewise applies them in order,
    # so a later mask overwrites an earlier one.
    masks = [grid >= edge for edge in x]
    return np.piecewise(grid, masks, y)
# Load the raw training transfer functions.
# NOTE(review): pickle.load can execute arbitrary code; this assumes
# ./data/trainTF.pkl is a trusted, locally produced file.
with open('./data/trainTF.pkl', 'rb') as fh:
    tf_train = pickle.load(fh)

# Sampling resolution of each transfer function. It is also the vector
# dimension handed to faiss, so it is passed to sample_tf explicitly rather
# than relying on that function's default matching this constant.
d = 1000

# Resample every record onto a d-point grid and stack the results into one
# float32 matrix of shape (len(tf_train), d).
# Records are accessed by position because the pickle's container type isn't
# visible here (presumably a list of dicts with 'x'/'y' keys; TODO confirm).
tf = np.array(
    [
        sample_tf(tf_train[i]['x'], tf_train[i]['y'], ndim=d)
        for i in tqdm(range(len(tf_train)))
    ],
    dtype=np.float32,  # float32 is the vector dtype faiss works with
)

# k-means clustering of the sampled transfer functions.
ncentroids = 1000
niter = 200
verbose = True
kmeans = faiss.Kmeans(d, ncentroids, niter=niter, verbose=verbose, gpu=True)
kmeans.train(tf)
centroids = kmeans.centroids

# Exact (brute-force L2) index over all training vectors. It is used to find
# the nNN training samples closest to each centroid.
index = faiss.IndexFlatL2(d)
index.add(tf)
nNN = 1000
D, I = index.search(centroids, nNN)

# Per faiss search semantics, I has one row per centroid holding the indices
# (into tf / tf_train) of its nNN nearest training samples.
np.save('./data/centroids_train.npy', centroids)
np.save('./data/clusters_train.npy', I)