reranker / app.py
caubetotbunggg's picture
Update app.py
bd13d99 verified
import torch
import gradio as gr
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Load model & tokenizer
MODEL_NAME = "AITeamVN/Vietnamese_Reranker"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()
MAX_LENGTH = 2304
def rerank(batch_pairs):
if not batch_pairs:
return []
with torch.no_grad():
inputs = tokenizer(
batch_pairs,
padding=True,
truncation=True,
return_tensors="pt",
max_length=MAX_LENGTH
)
scores = model(**inputs, return_dict=True).logits.view(-1, ).float()
# convert tensor -> list float
return [float(s) for s in scores]
# Gradio UI
with gr.Blocks() as demo:
gr.Markdown("## 🇻🇳 Vietnamese Reranker Demo")
inp = gr.JSON(
label="Nhập batch_pairs [[query, doc], ...]",
value=[
["Trí tuệ nhân tạo là gì?", "Trí tuệ nhân tạo là công nghệ giúp máy móc suy nghĩ và học hỏi như con người."],
["Trí tuệ nhân tạo là gì?", "Giấc ngủ giúp cơ thể nghỉ ngơi và hồi phục năng lượng."]
]
)
out = gr.JSON(label="Kết quả Rerank (list float)")
btn = gr.Button("Rerank")
btn.click(rerank, inputs=inp, outputs=out)
demo.launch()