File size: 3,357 Bytes
93ecd47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
import torch.nn as nn
from torch import tensor 
from transformers import BertModel, BertTokenizer
import pandas as pd
import requests


class EmbeddingModel(nn.Module):
    """Sentence encoder: a BERT backbone followed by attention-masked mean pooling.

    Args:
        bertName: Model id passed to both ``BertTokenizer.from_pretrained`` and
            ``BertModel.from_pretrained``; other BERT checkpoints can also be used.
    """

    def __init__(self, bertName="bert-base-uncased"):  # other bert models can also be supported
        super().__init__()
        self.bertName = bertName
        # use BERT model
        self.tokenizer = BertTokenizer.from_pretrained(self.bertName)
        self.model = BertModel.from_pretrained(self.bertName)

    def forward(self, s, device="cuda", max_length=256):
        """Embed texts as the mean of their non-padding token states.

        Args:
            s: Text or list of texts, passed straight to the tokenizer.
            device: Device the tokenized inputs are moved to. Only the inputs
                are moved here; ``self.model`` must already be on that device.
            max_length: Every input is padded/truncated to exactly this many
                tokens (defaults to 256, the previously hard-coded value).

        Returns:
            Tensor of shape ``[B, emb]``: one pooled embedding per input text.
        """
        # tokenizer output also carries the attention_mask used for pooling
        tokens = self.tokenizer(
            s,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=max_length,
        ).to(device)

        token_embeddings = self.model(**tokens).last_hidden_state  # [B, T, emb]

        # Zero out padding positions so they do not contribute to the mean.
        mask = tokens.attention_mask[..., None].to(token_embeddings.dtype)  # [B, T, 1]
        summed = (token_embeddings * mask).sum(1)  # [B, emb]
        # Defensive: avoid dividing by zero should a row have no attended tokens.
        valid_tokens = mask.sum(1).clamp(min=1)  # [B, 1]
        return summed / valid_tokens  # [B, emb]

    def cos_score(self, q, a, eps=1e-8):
        """Row-wise cosine similarity between ``q`` and ``a``.

        Matches ``nn.CosineSimilarity(dim=1)(q, a)`` up to how ``eps`` is
        applied: element ``i`` of the result is the cosine of ``q[i]`` and ``a[i]``.

        Args:
            q: 2-D tensor ``[B, emb]``.
            a: 2-D tensor with the same shape as ``q``.
            eps: Lower bound on each row norm, so an all-zero row scores 0
                instead of producing NaN.

        Returns:
            1-D tensor of shape ``[B]``.
        """
        q_norm = q / q.norm(dim=1, keepdim=True).clamp_min(eps)
        a_norm = a / a.norm(dim=1, keepdim=True).clamp_min(eps)
        # Row-wise dot product: O(B * emb) instead of materialising the full
        # [B, B] similarity matrix and keeping only its diagonal.
        return (q_norm * a_norm).sum(dim=1)
    
# contrastive training
class TrainModel(nn.Module):
    """Training wrapper around ``EmbeddingModel`` for similarity regression.

    The loss is the mean squared error between the cosine similarity of the
    two sentence embeddings and a target similarity score.

    Args:
        bertName: BERT checkpoint forwarded to ``EmbeddingModel``
            (defaults to the previously hard-coded "bert-base-uncased").
    """

    def __init__(self, bertName="bert-base-uncased"):
        super().__init__()
        self.m = EmbeddingModel(bertName)

    def forward(self, s1, s2, score, device="cuda"):
        """Compute the MSE loss between predicted and target similarity.

        Args:
            s1: First batch of texts.
            s2: Second batch of texts, paired element-wise with ``s1``.
            score: Target similarity per pair (tensor or array-like).
            device: Device the tokenized inputs are moved to (forwarded to
                ``EmbeddingModel.forward``).

        Returns:
            ``(loss, cos_score)``: the scalar MSE loss and the predicted
            per-pair cosine similarities.
        """
        cos_score = self.m.cos_score(self.m(s1, device=device), self.m(s2, device=device))
        # Align the target with the prediction: a device mismatch raises, and a
        # float64 target (e.g. built from pandas) can fail during backward.
        score = torch.as_tensor(score, dtype=cos_score.dtype, device=cos_score.device)
        loss = nn.MSELoss()(cos_score, score)
        return loss, cos_score
    
def searchWiki(s, timeout=10):
    """Fetch the plain-text intro section of the English Wikipedia page titled ``s``.

    Args:
        s: Page title, sent as the ``titles`` query parameter.
        timeout: Seconds to wait for the API before giving up; without it a
            stalled connection would block forever.

    Returns:
        The intro extract with all newline characters removed.

    Raises:
        requests.HTTPError: The API answered with an HTTP error status.
        requests.Timeout: The request did not finish within ``timeout``.
        KeyError: The response has no extract for the title (presumably the
            page does not exist).
    """
    response = requests.get(
        'https://en.wikipedia.org/w/api.php',
        params={
            'action': 'query',
            'format': 'json',
            'titles': s,
            'prop': 'extracts',
            'exintro': True,
            'explaintext': True,
        },
        timeout=timeout,
    )
    # Fail on HTTP errors here rather than later with a confusing JSON/KeyError.
    response.raise_for_status()
    pages = response.json()['query']['pages']
    # Only one title is requested, so take the single page entry.
    page = next(iter(pages.values()))
    extract = page.get('extract')
    if extract is None:
        raise KeyError(f"no Wikipedia extract found for title {s!r}")
    return extract.replace("\n", "")

# sentence chunking
def chunk(w):
    """Naively split text ``w`` into sentence-like pieces at every ".".

    The periods themselves are dropped and pieces are not stripped. Empty
    strings appear for leading, trailing or consecutive periods.
    Abbreviations and decimals (e.g. "3.14") are split as well.
    """
    pieces = w.split(sep=".")
    return pieces

def generate_chunk_data(concepts):
    """Download the Wikipedia intro for each concept and split it into sentences.

    Args:
        concepts: Iterable of Wikipedia page titles, each passed to ``searchWiki``.

    Returns:
        List of stripped sentence chunks, each with a trailing "." re-appended,
        in concept order. Empty fragments (from consecutive or trailing periods)
        are dropped.
    """
    chunk_data = []
    for concept in concepts:
        text = searchWiki(concept).replace("\n", "")
        for piece in chunk(text):
            sentence = piece.strip()
            # Skipping empty pieces here replaces the old quadratic
            # "while '.' in chunk_data: remove" clean-up; a lone "." was exactly
            # an empty piece with the period re-appended.
            if sentence:
                chunk_data.append(sentence + ".")
    return chunk_data

@torch.no_grad()
def generate_chunk_emb(m, chunk_data):
    """Embed all chunks in a single call on the CPU, with autograd disabled.

    Args:
        m: Embedding model called as ``m(chunk_data, device="cpu")``
            (presumably an ``EmbeddingModel``).
        chunk_data: Texts to embed, passed to ``m`` unchanged.

    Returns:
        Whatever ``m`` returns for the batch (an embedding tensor for
        ``EmbeddingModel``).
    """
    return m(chunk_data, device="cpu")

def search_document(s, chunk_data, chunk_emb, m, topk=3, verbose=True):
    """Return the chunks most similar to query ``s``, best match first.

    Args:
        s: Query text.
        chunk_data: Chunk texts, index-aligned with the rows of ``chunk_emb``.
        chunk_emb: Chunk embeddings, presumably ``[N, emb]`` as produced by
            ``generate_chunk_emb``.
        m: Embedding model exposing ``m(texts, device=...)`` and ``m.cos_score``.
        topk: Maximum number of chunks to return. Values larger than the number
            of chunks are clamped instead of making ``torch.topk`` raise.
        verbose: Print all scores and the selected top scores (the previous,
            unconditional behaviour).

    Returns:
        List of up to ``topk`` entries of ``chunk_data`` in descending score order.
    """
    question = [s]
    with torch.no_grad():
        # Broadcast the single query row against every chunk row, because
        # cos_score compares row i of one argument with row i of the other.
        query_emb = m(question, device="cpu").expand(chunk_emb.shape)
        result_score = m.cos_score(query_emb, chunk_emb)
    if verbose:
        print(result_score)
    # torch.topk raises if k exceeds the size of the scored dimension.
    k = min(topk, result_score.size(-1))
    _, idxs = torch.topk(result_score, k)
    top = idxs.flatten().tolist()
    if verbose:
        print([result_score.flatten()[idx] for idx in top])
    return [chunk_data[idx] for idx in top]