# aliokaccound's picture
# Update app.py
# fbde8a9 verified
import numpy as np
import onnxruntime as ort
from transformers import PreTrainedTokenizerFast
import gradio as gr
# -----------------------------
# Tokenizer (Local from repo)
# -----------------------------
# Built from the "tokenizer.json" file given as a relative path, so the file
# is expected to be reachable from the process's working directory.
# This runs at import time; predict() reuses the same tokenizer object.
tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
# -----------------------------
# ONNX Model Session
# -----------------------------
# A single inference session over "model.onnx", created once at import time
# and shared by every predict() call.
# NOTE(review): no execution providers are passed, so onnxruntime falls back
# to its defaults -- confirm that is intended for the deployment target.
session = ort.InferenceSession("model.onnx")
# -----------------------------
# Labels (as defined by your model)
# -----------------------------
# Intent names addressed by position: predict() reports LABELS[i] for the
# score at index i, so the order of the entries matters.
LABELS = [
    "ASK_PRICE",                         # 0
    "ASK_DETAILS",                       # 1
    "ASK_DISCOUNT",                      # 2
    "ORDER_CREATE",                      # 3
    "META_UNCLASSIFIED",                 # 4
    "COMPLAINT_PRODUCT_DAMAGED",         # 5
    "RESOLUTION_REQUEST_RETURN_REFUND",  # 6
    "COMPLAINT_POOR_QUALITY",            # 7
    "COMPLAINT_WRONG_ITEM",              # 8
    "META_CLOSING",                      # 9
    "META_GREETING",                     # 10
    "META_FEEDBACK_POSITIVE",            # 11
]
# -----------------------------
# Prediction Function
# -----------------------------
def predict(text: str, top_k: int = 5) -> list:
    """Return the ``top_k`` highest-scoring intent labels for *text*.

    The text is tokenized with the module-level ``tokenizer``, run through the
    module-level ONNX ``session``, and the resulting logits are converted to
    probabilities with a softmax.

    Args:
        text: Input sentence to classify.
        top_k: Maximum number of predictions to return. The value is coerced
            to ``int`` and clamped to ``[0, len(LABELS)]``.

    Returns:
        A list of ``{"label": str, "score": float}`` dicts ordered from the
        highest score to the lowest.
    """
    # Tokenize input
    inputs = tokenizer(text, return_tensors="np", padding=True)

    # ONNX inference
    outputs = session.run(
        None,
        {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
        },
    )

    # Assumes outputs[0] holds the logits with a leading batch axis of size 1
    # -- TODO confirm against the exported model's output signature.
    logits = outputs[0][0]

    # Numerically stable softmax: shifting by the max leaves the result
    # mathematically unchanged but keeps np.exp from overflowing to inf (which
    # would turn every score into NaN) when the logits are large.
    exp_logits = np.exp(logits - np.max(logits))
    scores = exp_logits / np.sum(exp_logits)

    # A UI slider may deliver the value as a float, which is not a valid slice
    # bound; a negative value would otherwise slice from the end instead of
    # returning nothing.
    safe_top_k = max(0, min(int(top_k), len(LABELS)))
    top_indices = np.argsort(scores)[::-1][:safe_top_k]

    return [{"label": LABELS[i], "score": float(scores[i])} for i in top_indices]
# -----------------------------
# Gradio Interface
# -----------------------------
# Input widgets, listed in the same order as predict()'s parameters
# (text first, then top_k).
_text_input = gr.Textbox(lines=2, placeholder="Type your Bengali text here...")
_top_k_slider = gr.Slider(1, len(LABELS), value=5, step=1, label="Top K")

# Web UI wrapping predict(); the returned list of dicts is rendered as JSON.
iface = gr.Interface(
    fn=predict,
    inputs=[_text_input, _top_k_slider],
    outputs="json",
    title="Intent Router ONNX Model",
    description="Enter Bengali text to get top_k intent predictions",
)

if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    iface.launch(server_name="0.0.0.0", server_port=7860)