# SentenceEmbedding / myTextEmbedding.py
import torch
import torch.nn as nn
import requests
from transformers import BertModel, BertTokenizer


class EmbeddingModel(nn.Module):
    def __init__(self, bertName="bert-base-uncased"):  # other BERT checkpoints can also be used
        super().__init__()
        self.bertName = bertName
        # use a pretrained BERT encoder and its matching tokenizer
        self.tokenizer = BertTokenizer.from_pretrained(self.bertName)
        self.model = BertModel.from_pretrained(self.bertName)

    def forward(self, s, device="cuda"):
        # tokenize; the returned dict also includes attention_mask
        tokens = self.tokenizer(s, return_tensors='pt', padding="max_length", truncation=True, max_length=256).to(device)
        # per-token embeddings from the last encoder layer
        output = self.model(**tokens)
        tokens_embeddings = output.last_hidden_state  # [B, T, emb]
        # mean pooling over valid (non-padding) tokens to get one text embedding
        embeddings = tokens_embeddings * tokens.attention_mask[..., None]  # zero out padding, [B, T, emb]
        embeddings = embeddings.sum(1)  # [B, emb]
        valid_tokens = tokens.attention_mask.sum(1)  # number of real tokens per input, [B]
        embeddings = embeddings / valid_tokens[..., None]  # [B, emb]
        return embeddings
    # cosine similarity implemented from scratch; equivalent to nn.CosineSimilarity(dim=1)(q, a)
    def cos_score(self, q, a):
        q_norm = q / q.pow(2).sum(dim=1, keepdim=True).pow(0.5)
        a_norm = a / a.pow(2).sum(dim=1, keepdim=True).pow(0.5)
        return (q_norm @ a_norm.T).diagonal()  # row-wise similarity between paired rows of q and a
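

# A quick sanity-check sketch (not part of the original module): cos_score should
# match PyTorch's built-in nn.CosineSimilarity row for row. The shapes and random
# tensors below are illustrative assumptions.
def _check_cos_score():
    # cos_score never touches self, so call it unbound to avoid downloading BERT here
    q = torch.randn(4, 768)
    a = torch.randn(4, 768)
    ours = EmbeddingModel.cos_score(None, q, a)
    ref = nn.CosineSimilarity(dim=1)(q, a)
    assert torch.allclose(ours, ref, atol=1e-6)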


# contrastive training: pull each pair's cosine similarity toward a target score
class TrainModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.m = EmbeddingModel("bert-base-uncased")

    def forward(self, s1, s2, score):
        # score holds the gold similarity target for each (s1, s2) pair
        cos_score = self.m.cos_score(self.m(s1), self.m(s2))
        loss = nn.MSELoss()(cos_score, score)
        return loss, cos_score
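

# Minimal training-loop sketch (an assumption about usage, not code from this repo):
# it drives TrainModel with STS-style (sentence1, sentence2, similarity) triples.
# The optimizer choice, learning rate, and epoch count are made-up defaults.
def train_sketch(s1_list, s2_list, gold_scores, epochs=3, lr=2e-5):
    # EmbeddingModel.forward moves tokens to "cuda" by default, so this assumes a GPU
    model = TrainModel().to("cuda")
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    target = torch.tensor(gold_scores, dtype=torch.float, device="cuda")
    for _ in range(epochs):
        opt.zero_grad()
        loss, _ = model(s1_list, s2_list, target)
        loss.backward()
        opt.step()
    return model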


# fetch the plain-text intro section of an English Wikipedia article
def searchWiki(s):
    response = requests.get(
        'https://en.wikipedia.org/w/api.php',
        params={
            'action': 'query',
            'format': 'json',
            'titles': s,
            'prop': 'extracts',
            'exintro': True,
            'explaintext': True,
        },
    ).json()
    page = next(iter(response['query']['pages'].values()))
    return page['extract'].replace("\n", "")


# naive sentence chunking: split on periods
def chunk(w):
    return w.split(".")


def generate_chunk_data(concepts):
    # searchWiki already strips newlines, so no second replace is needed
    wiki_data = [searchWiki(c) for c in concepts]
    chunk_data = []
    for w in wiki_data:
        chunk_data = chunk_data + chunk(w)
    # re-attach the period lost in splitting, dropping chunks that are empty after stripping
    chunk_data = [c.strip() + "." for c in chunk_data if c.strip()]
    return chunk_data


def generate_chunk_emb(m, chunk_data):
    # embed all chunks once, without tracking gradients
    with torch.no_grad():
        emb = m(chunk_data, device="cpu")
    return emb


def search_document(s, chunk_data, chunk_emb, m, topk=3):
    question = [s]
    with torch.no_grad():
        # broadcast the single question embedding across every chunk embedding
        result_score = m.cos_score(m(question, device="cpu").expand(chunk_emb.shape), chunk_emb)
    _, idxs = torch.topk(result_score, topk)
    print([result_score.flatten()[idx].item() for idx in idxs.flatten().tolist()])  # top-k scores, for inspection
    return [chunk_data[idx] for idx in idxs.flatten().tolist()]
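

# End-to-end usage sketch (an assumption about how these pieces fit together, not
# code from this repo): build a chunk index from a few Wikipedia intros, then
# retrieve the chunks closest to a question. The concept list and question are
# made-up examples; downloading BERT and hitting the Wikipedia API require network access.
if __name__ == "__main__":
    m = EmbeddingModel()
    concepts = ["Python (programming language)", "BERT (language model)"]
    chunk_data = generate_chunk_data(concepts)
    chunk_emb = generate_chunk_emb(m, chunk_data)
    hits = search_document("Who created Python?", chunk_data, chunk_emb, m, topk=3)
    for h in hits:
        print(h)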