test2 / app.py
sritang's picture
Update app.py
9aa38c0
#import numpy as np
import gradio as gr
#import random
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from torch import tensor as torch_tensor
from datasets import load_dataset
"""# import models"""
# First stage: the bi-encoder retrieves the top_k candidate documents.
bi_encoder = SentenceTransformer("multi-qa-MiniLM-L6-cos-v1")
# Long passages are truncated to 256 tokens.
bi_encoder.max_seq_length = 256

# Second stage: the cross-encoder re-ranks the bi-encoder's result list to
# improve its quality.
cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
"""# import datasets"""
# Passage texts: the 'psg' column of the dataset's train split, as a plain list.
dataset = load_dataset("gfhayworth/hack_policy", split='train')
mypassages = list(dataset.to_pandas()['psg'])

# Companion embedding dataset, loaded as a DataFrame and converted to a tensor.
# NOTE(review): presumably these are precomputed passage embeddings with row i
# matching mypassages[i] — confirm against how the dataset was built.
# NOTE(review): the tensor dtype follows the DataFrame's dtype — confirm it
# matches what the bi-encoder produces for queries.
dataset_embed = load_dataset("gfhayworth/hack_policy_embed", split='train')
dataset_embed_pd = dataset_embed.to_pandas()
# to_numpy() is the pandas-recommended replacement for the .values attribute.
mycorpus_embeddings = torch_tensor(dataset_embed_pd.to_numpy())