File size: 913 Bytes
9aa38c0
 
 
 
 
4d47d7a
b35e25d
9aa38c0
4d47d7a
9aa38c0
 
0cf52e4
9aa38c0
 
 
 
 
 
 
32e5396
9aa38c0
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
#import numpy as np
import gradio as gr
#import random
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from torch import tensor as torch_tensor
from datasets import load_dataset

# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------

# Stage 1 (retrieval): the bi-encoder is used to fetch the top_k candidate
# documents. Passages longer than 256 tokens are truncated.
bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
bi_encoder.max_seq_length = 256

# Stage 2 (re-ranking): the cross-encoder re-scores the bi-encoder's top_k
# results to improve the quality of the final ranking.
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# ---------------------------------------------------------------------------
# Datasets
# ---------------------------------------------------------------------------

# Passage texts: the 'psg' column of the train split, as a plain Python list.
dataset = load_dataset("gfhayworth/hack_policy", split='train')
mypassages = list(dataset.to_pandas()['psg'])

# Precomputed passage embeddings. Every column of the frame is stacked into a
# single tensor, so the frame is assumed to hold only embedding dimensions,
# one row per passage in the same order as `mypassages`
# (NOTE(review): confirm against the dataset schema).
# The tensor's dtype follows the frame's column dtypes; if that is float64 it
# may not match float32 query embeddings at search time — TODO confirm.
dataset_embed = load_dataset("gfhayworth/hack_policy_embed", split='train')
dataset_embed_pd = dataset_embed.to_pandas()
mycorpus_embeddings = torch_tensor(dataset_embed_pd.values)