# jpmendes's picture
# Update app.py
# 1faf52d verified
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DistilBertTokenizer,
)
# -------------------------------
# Model configuration
# -------------------------------
MODEL_NAME = "hermanshid/distilbert-id-law"

# Load the tokenizer from the fine-tuned checkpoint itself so vocabulary and
# special tokens are guaranteed to match the model.  The previous code
# hard-coded the "distilbert-base-uncased" tokenizer, which silently produces
# wrong token ids whenever the checkpoint was trained with a different vocab.
# Fall back to the base tokenizer only if the repo ships no tokenizer files.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
except (OSError, ValueError):
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Custom fine-tuned sequence-classification model.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)

# Inference only: move to GPU when available and disable dropout/batch-norm
# training behaviour.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# -------------------------------
# Placeholder display labels ("Class 0", "Class 1", ...) — the checkpoint's
# config may not carry meaningful id2label names.
# -------------------------------
num_labels = model.config.num_labels
labels = [f"Class {i}" for i in range(num_labels)]
# -------------------------------
# Inference function
# -------------------------------
def respond(message, history, system_message=None, max_tokens=None, temperature=None, top_p=None, hf_token=None):
    """Classify *message* with the DistilBERT model and format the result.

    Parameters
    ----------
    message : str
        User text to classify.
    history : list
        Chat history supplied by gr.ChatInterface (unused).
    system_message, max_tokens, temperature, top_p, hf_token
        Accepted only because ChatInterface forwards its additional_inputs
        positionally; they have no effect on classification.

    Returns
    -------
    str
        Human-readable report with the predicted label, per-class
        probabilities (softmax) or the sigmoid probability for a
        single-logit head, and the raw logits.
    """
    # Guard clause: don't tokenize an empty/whitespace-only submission.
    if not message or not message.strip():
        return "Please enter some text to classify."

    inputs = tokenizer(message, return_tensors="pt", truncation=True, padding=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # `num_labels` is the module-level constant; no need to re-read the
    # model config on every request as the original code did.
    if num_labels == 1:
        # Single-output head: interpret the lone logit through a sigmoid.
        prob = torch.sigmoid(logits).item()
        return f"Predicted logit: {logits.item():.4f}\nProbability (sigmoid): {prob:.4f}"

    probs = F.softmax(logits, dim=-1).squeeze().tolist()
    pred_index = torch.argmax(logits, dim=-1).item()

    # Build the report with join() instead of quadratic string +=.
    lines = [f"Predicted label: {labels[pred_index]}", "Probabilities:"]
    lines.extend(f"{label}: {p:.4f}" for label, p in zip(labels, probs))
    lines.append(f"\nLogits: {logits.squeeze().tolist()}")
    return "\n".join(lines)
# -------------------------------
# Chatbot-style interface
# -------------------------------
# NOTE(review): the system-message box and the three sliders are forwarded
# to respond() but ignored there — they only fill the ChatInterface
# additional-input slots; confirm whether they should be removed or wired up.
chatbot = gr.ChatInterface(
    fn=respond,
    type="messages",  # history is passed as a list of {"role", "content"} dicts
    additional_inputs=[
        gr.Textbox(value="Classificador jurídico DistilBERT.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
    ]
)
# Wrap the chat interface in a Blocks layout so a sidebar can be added.
with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.LoginButton()  # optional Hugging Face login button
    chatbot.render()
# -------------------------------
# Application entry point
# -------------------------------
if __name__ == "__main__":
    demo.launch()