# aamirtaymoor's picture
# Upload 313 files
# b03fd59 verified
#!/usr/bin/env python3
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import io
import numpy as np
import collections
def load_vectors(fname, maxload=200000, norm=True, center=False, verbose=True):
    """Load fastText-style word vectors from a text file.

    The first line of the file is "<n> <d>"; each following line holds a word
    and d space-separated float components.

    Args:
        fname: path to the embedding file.
        maxload: if > 0, load at most this many vectors.
        norm: L2-normalize each row.
        center: subtract the mean vector, then re-normalize rows.
        verbose: print progress messages.

    Returns:
        (words, x): list of n words and an (n, d) float array of vectors.
    """
    if verbose:
        print("Loading vectors from %s" % fname)
    # Use a context manager so the handle is closed deterministically
    # (the original leaked the open file object).
    with io.open(fname, 'r', encoding='utf-8', newline='\n', errors='ignore') as fin:
        n, d = map(int, fin.readline().split())
        if maxload > 0:
            n = min(n, maxload)
        x = np.zeros([n, d])
        words = []
        for i, line in enumerate(fin):
            if i >= n:
                break
            tokens = line.rstrip().split(' ')
            words.append(tokens[0])
            x[i, :] = np.array(tokens[1:], dtype=float)
    if norm:
        # epsilon guards against division by zero for all-zero rows
        x /= np.linalg.norm(x, axis=1)[:, np.newaxis] + 1e-8
    if center:
        x -= x.mean(axis=0)[np.newaxis, :]
        x /= np.linalg.norm(x, axis=1)[:, np.newaxis] + 1e-8
    if verbose:
        print("%d word vectors loaded" % (len(words)))
    return words, x
def idx(words):
w2i = {}
for i, w in enumerate(words):
if w not in w2i:
w2i[w] = i
return w2i
def save_vectors(fname, x, words):
    """Write word vectors in fastText text format.

    First line is "n d"; each following line is a word and its d components
    rendered with four decimal places.
    """
    n, d = x.shape
    with io.open(fname, 'w', encoding='utf-8') as fout:
        fout.write(u"%d %d\n" % (n, d))
        for i in range(n):
            values = " ".join("%.4f" % v for v in x[i, :])
            fout.write(words[i] + " " + values + "\n")
def save_matrix(fname, x):
    """Write a matrix as text: an "n d" header, then one row per line with
    values rendered to four decimal places."""
    n, d = x.shape
    with io.open(fname, 'w', encoding='utf-8') as fout:
        fout.write(u"%d %d\n" % (n, d))
        for row in x:
            fout.write(" ".join("%.4f" % v for v in row) + "\n")
def procrustes(X_src, Y_tgt):
    """Solve the orthogonal Procrustes problem.

    Returns the orthogonal matrix W built from the SVD of Y_tgt^T X_src
    (W = U V^T), the closed-form rotation aligning the source space to the
    target space.
    """
    cross = Y_tgt.T.dot(X_src)
    U, _, Vt = np.linalg.svd(cross)
    return U.dot(Vt)
def select_vectors_from_pairs(x_src, y_tgt, pairs):
    """Gather aligned embedding rows for a list of (src_idx, tgt_idx) pairs.

    Returns two (len(pairs), d) float arrays whose k-th rows are
    x_src[pairs[k][0]] and y_tgt[pairs[k][1]].
    """
    count = len(pairs)
    dim = x_src.shape[1]
    x = np.zeros([count, dim])
    y = np.zeros([count, dim])
    for row, (src_i, tgt_j) in enumerate(pairs):
        x[row, :] = x_src[src_i, :]
        y[row, :] = y_tgt[tgt_j, :]
    return x, y
def load_lexicon(filename, words_src, words_tgt, verbose=True):
    """Load a bilingual lexicon of "src_word tgt_word" lines.

    Args:
        filename: path to the lexicon file.
        words_src: source vocabulary (list of words, index = embedding row).
        words_tgt: target vocabulary.
        verbose: print the coverage of the source vocabulary.

    Returns:
        (lexicon, size): lexicon maps a source word index to the set of
        acceptable target word indices; size is the number of distinct source
        words appearing in the file (as a float).
    """
    lexicon = collections.defaultdict(set)
    idx_src, idx_tgt = idx(words_src), idx(words_tgt)
    vocab = set()
    # Close the file deterministically (the original leaked the handle).
    with io.open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            word_src, word_tgt = line.split()
            if word_src in idx_src and word_tgt in idx_tgt:
                lexicon[idx_src[word_src]].add(idx_tgt[word_tgt])
            # Count every source word in the file, matched or not; the
            # original added it only inside the `if`, which made the reported
            # coverage always 1.0 and the returned total too small.
            vocab.add(word_src)
    if verbose:
        coverage = len(lexicon) / float(len(vocab))
        print("Coverage of source vocab: %.4f" % (coverage))
    return lexicon, float(len(vocab))
def load_pairs(filename, idx_src, idx_tgt, verbose=True):
    """Load (src_idx, tgt_idx) training pairs from a "src tgt" word-pair file.

    Args:
        filename: path to the pair file, one space-separated pair per line.
        idx_src: dict mapping source words to row indices.
        idx_tgt: dict mapping target words to row indices.
        verbose: print how many pairs were found.

    Returns:
        List of (src_idx, tgt_idx) tuples for pairs where both words are known.
    """
    pairs = []
    tot = 0
    # Close the file deterministically (the original leaked the handle).
    with io.open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            a, b = line.rstrip().split(' ')
            tot += 1
            if a in idx_src and b in idx_tgt:
                pairs.append((idx_src[a], idx_tgt[b]))
    if verbose:
        # Guard against an empty file (tot == 0 would divide by zero).
        coverage = (1.0 * len(pairs)) / tot if tot else 0.0
        print("Found pairs for training: %d - Total pairs in file: %d - Coverage of pairs: %.4f" % (len(pairs), tot, coverage))
    return pairs
def compute_nn_accuracy(x_src, x_tgt, lexicon, bsz=100, lexicon_size=-1):
    """Precision@1 of nearest-neighbor translation under cosine similarity.

    Args:
        x_src: (n_src, d) source embeddings.
        x_tgt: (n_tgt, d) target embeddings.
        lexicon: maps source index -> set of acceptable target indices.
        bsz: batch size for the score computation.
        lexicon_size: denominator for the accuracy; defaults to len(lexicon).

    Returns:
        Fraction of lexicon entries whose nearest target is a valid translation.
    """
    if lexicon_size < 0:
        lexicon_size = len(lexicon)
    idx_src = list(lexicon.keys())
    acc = 0.0
    # Normalize local copies: the original used /= which silently mutated the
    # caller's arrays as a side effect.
    x_src = x_src / (np.linalg.norm(x_src, axis=1)[:, np.newaxis] + 1e-8)
    x_tgt = x_tgt / (np.linalg.norm(x_tgt, axis=1)[:, np.newaxis] + 1e-8)
    for i in range(0, len(idx_src), bsz):
        e = min(i + bsz, len(idx_src))
        # scores[t, c] = cosine(target t, source idx_src[i + c])
        scores = np.dot(x_tgt, x_src[idx_src[i:e]].T)
        pred = scores.argmax(axis=0)
        for j in range(i, e):
            if pred[j - i] in lexicon[idx_src[j]]:
                acc += 1.0
    return acc / lexicon_size
def compute_csls_accuracy(x_src, x_tgt, lexicon, lexicon_size=-1, k=10, bsz=1024):
    """Precision@1 of translation under the CSLS criterion.

    CSLS rescales cosine similarity by the average similarity of each target
    to its k nearest sources, penalizing hub targets. (The source-side penalty
    is constant per source row and does not change the argmax, so it is
    omitted here — same as the original.)

    Args:
        x_src: (n_src, d) source embeddings (n_src must be >= k).
        x_tgt: (n_tgt, d) target embeddings.
        lexicon: maps source index -> set of acceptable target indices.
        lexicon_size: denominator for the accuracy; defaults to len(lexicon).
        k: neighborhood size for the CSLS density term.
        bsz: batch size over targets for the density computation.

    Returns:
        Fraction of lexicon entries whose CSLS-nearest target is valid.
    """
    if lexicon_size < 0:
        lexicon_size = len(lexicon)
    idx_src = list(lexicon.keys())
    # Normalize local copies: the original used /= which silently mutated the
    # caller's arrays as a side effect.
    x_src = x_src / (np.linalg.norm(x_src, axis=1)[:, np.newaxis] + 1e-8)
    x_tgt = x_tgt / (np.linalg.norm(x_tgt, axis=1)[:, np.newaxis] + 1e-8)
    sr = x_src[idx_src]
    sc = np.dot(sr, x_tgt.T)
    similarities = 2 * sc
    # sc2[t] = mean cosine of target t to its k nearest source vectors.
    sc2 = np.zeros(x_tgt.shape[0])
    for i in range(0, x_tgt.shape[0], bsz):
        j = min(i + bsz, x_tgt.shape[0])
        sc_batch = np.dot(x_tgt[i:j, :], x_src.T)
        dotprod = np.partition(sc_batch, -k, axis=1)[:, -k:]
        sc2[i:j] = np.mean(dotprod, axis=1)
    similarities -= sc2[np.newaxis, :]
    nn = np.argmax(similarities, axis=1).tolist()
    correct = 0.0
    # Fresh loop variable: the original reused `k`, shadowing the k-NN
    # parameter after its last use (harmless but fragile).
    for q in range(0, len(lexicon)):
        if nn[q] in lexicon[idx_src[q]]:
            correct += 1.0
    return correct / lexicon_size