Commit ·
4208ff2
1
Parent(s): 47eb45b
Change model and Add multi_label option
Browse files
app.py
CHANGED
|
@@ -5,7 +5,8 @@ import pandas as pd
|
|
| 5 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 6 |
from InstructorEmbedding import INSTRUCTOR
|
| 7 |
|
| 8 |
-
# Classifier that predict() calls with candidate labels when no corpus entry
# is similar enough to the question.
pipe = pipeline(model="facebook/bart-large-mnli")

# Embedding model used for both the corpus and incoming questions.
model = INSTRUCTOR('hkunlp/instructor-large')

# Intent table; the file uses ';' as its field separator.
df = pd.read_csv('intent.csv', sep=';')
|
|
@@ -20,14 +21,14 @@ data = [
|
|
| 20 |
# Embed the whole corpus once at module load; predict() then only has to
# encode the incoming question and compare it against these vectors.
corpus_embeddings = model.encode(data)
|
| 21 |
|
| 22 |
|
| 23 |
-
def predict(question, lower_threshold, tags):
    """Answer a question by retrieval, else classify it against ``tags``.

    The question is embedded and compared with ``corpus_embeddings`` using
    cosine similarity. When the best score is below ``lower_threshold`` the
    question is handed to ``pipe`` with the comma-separated ``tags`` as
    candidate labels; otherwise the last element of the best-matching
    ``data`` entry is returned.

    Args:
        question: Text to look up or classify.
        lower_threshold: Similarity cut-off; converted with ``float()``.
        tags: Comma-separated candidate labels; blank entries are dropped.

    Returns:
        A dict containing ``query_similarity_score`` plus either the
        classifier output or a ``sequence`` entry taken from ``data``.
    """
    instruction_pair = ['Represent the question for retrieving supporting documents: ', question]
    question_vectors = model.encode([instruction_pair])
    scores = cosine_similarity(question_vectors, corpus_embeddings)
    # Only one query row exists, so the flat argmax is the corpus index.
    best_idx = np.argmax(scores)

    if scores[0][best_idx] < float(lower_threshold):
        # Retrieval is not confident enough: fall back to the classifier.
        stripped = (part.strip() for part in tags.split(","))
        labels = [label for label in stripped if label]
        result = pipe(question, candidate_labels=labels)
        result['query_similarity_score'] = scores[0][best_idx]
        return result

    return {
        "sequence": data[best_idx][-1],
        'query_similarity_score': scores[0][best_idx],
    }
|
|
@@ -37,7 +38,7 @@ def predict(question, lower_threshold, tags):
|
|
| 37 |
|
| 38 |
|
| 39 |
# Web UI: the three inputs map positionally onto predict()'s parameters
# (question, lower_threshold, tags); the returned dict is shown as JSON.
gr.Interface(
    fn=predict,
    inputs=["text", "number", "text"],
    outputs="json",
).launch()
|
| 42 |
|
| 43 |
|
|
|
|
| 5 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 6 |
from InstructorEmbedding import INSTRUCTOR
|
| 7 |
|
| 8 |
+
# Zero-shot classifier that predict() falls back to when no corpus entry is
# similar enough to the question (previously "facebook/bart-large-mnli").
pipe = pipeline(
    "zero-shot-classification",
    model="MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
)

# Embedding model used for both the corpus and incoming questions.
model = INSTRUCTOR('hkunlp/instructor-large')

# Intent table; the file uses ';' as its field separator.
df = pd.read_csv('intent.csv', sep=';')
|
|
|
|
| 21 |
# Embed the whole corpus once at module load; predict() then only has to
# encode the incoming question and compare it against these vectors.
corpus_embeddings = model.encode(data)
|
| 22 |
|
| 23 |
|
| 24 |
+
def predict(question, lower_threshold, tags, multi_label=False):
    """Answer a question by retrieval, falling back to zero-shot classification.

    The question is embedded and compared with ``corpus_embeddings`` using
    cosine similarity. When the best score is below ``lower_threshold`` the
    question is classified against the labels listed in ``tags``; otherwise
    the last element of the best-matching ``data`` entry is returned.

    Args:
        question: Text to look up or classify.
        lower_threshold: Similarity cut-off; converted with ``float()``.
        tags: Comma-separated candidate labels. Blank entries are dropped and
            ``None`` is treated as an empty string.
        multi_label: Forwarded to the classifier's ``multi_label`` option.
            Defaults to ``False`` so three-argument callers keep working.

    Returns:
        A dict that always contains ``sequence`` and
        ``query_similarity_score`` (a plain ``float``). On the classification
        path it is the classifier's own result dict with the score added; if
        no usable label was supplied, ``labels`` and ``scores`` are empty
        lists instead.
    """
    query = [['Represent the question for retrieving supporting documents: ', question]]
    query_embeddings = model.encode(query)
    similarities = cosine_similarity(query_embeddings, corpus_embeddings)

    # Only one query row exists, so the flat argmax is the corpus index.
    retrieved_doc_id = int(np.argmax(similarities))
    # Plain float so the JSON payload does not depend on NumPy scalar support.
    score = float(similarities[0][retrieved_doc_id])

    if score < float(lower_threshold):
        stripped = (part.strip() for part in (tags or "").split(","))
        labels = [label for label in stripped if label]
        if not labels:
            # The zero-shot pipeline rejects an empty label list; report
            # "nothing to classify" in the same shape instead of failing.
            return {
                "sequence": question,
                "labels": [],
                "scores": [],
                'query_similarity_score': score,
            }
        ans = pipe(question, candidate_labels=labels, multi_label=bool(multi_label))
        ans['query_similarity_score'] = score
        return ans

    return {"sequence": data[retrieved_doc_id][-1], 'query_similarity_score': score}
|
|
|
|
| 38 |
|
| 39 |
|
| 40 |
# Web UI: the four inputs map positionally onto predict()'s parameters
# (question, lower_threshold, tags, multi_label); the returned dict is shown
# as JSON.
gr.Interface(
    fn=predict,
    inputs=["text", "number", "text", "boolean"],
    outputs="json",
).launch()
|
| 43 |
|
| 44 |
|