# Hugging Face Space — status at capture time: Sleeping
import gradio as gr
import torch
import torch.nn.functional as F
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DistilBertTokenizer,
)
# -------------------------------
# Model configuration
# -------------------------------
MODEL_NAME = "hermanshid/distilbert-id-law"
# Base DistilBERT tokenizer, used only when the checkpoint ships no tokenizer.
FALLBACK_TOKENIZER_NAME = "distilbert-base-uncased"

try:
    # Prefer the tokenizer saved with the fine-tuned checkpoint so the token
    # ids fed to the model come from the vocabulary it was trained with.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
except (OSError, ValueError):
    # NOTE(review): the previous code always used the base DistilBERT
    # tokenizer; keep it as the fallback when the checkpoint repository has
    # no loadable tokenizer files.
    tokenizer = DistilBertTokenizer.from_pretrained(FALLBACK_TOKENIZER_NAME)

# Fine-tuned sequence-classification model, placed on the GPU when one is
# available and switched to eval mode for inference.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# -------------------------------
# Display labels
# -------------------------------
num_labels = model.config.num_labels


def _display_label(index):
    """Return the display name for class ``index``.

    Uses the checkpoint's ``id2label`` entry when it holds a real class name.
    Falls back to ``"Class {index}"`` when the entry is missing or is one of
    the generic ``LABEL_{index}`` placeholders transformers fills in by default.
    """
    name = (model.config.id2label or {}).get(index)
    if not name or name == f"LABEL_{index}":
        return f"Class {index}"
    return name


labels = [_display_label(i) for i in range(num_labels)]
| # ------------------------------- | |
| # Função de inferência | |
| # ------------------------------- | |
| def respond(message, history, system_message=None, max_tokens=None, temperature=None, top_p=None, hf_token=None): | |
| inputs = tokenizer(message, return_tensors="pt", truncation=True, padding=True) | |
| inputs = {k: v.to(device) for k, v in inputs.items()} | |
| with torch.no_grad(): | |
| logits = model(**inputs).logits | |
| # Número de classes | |
| num_labels = model.config.num_labels | |
| if num_labels == 1: | |
| prob = torch.sigmoid(logits).item() | |
| output = f"Predicted logit: {logits.item():.4f}\nProbability (sigmoid): {prob:.4f}" | |
| else: | |
| probs = F.softmax(logits, dim=-1).squeeze().tolist() | |
| pred_index = torch.argmax(logits, dim=-1).item() | |
| pred_label = labels[pred_index] | |
| output = f"Predicted label: {pred_label}\nProbabilities:\n" | |
| for i, p in enumerate(probs): | |
| output += f"{labels[i]}: {p:.4f}\n" | |
| output += f"\nLogits: {logits.squeeze().tolist()}" | |
| return output | |
# -------------------------------
# Chatbot-style interface
# -------------------------------
# No additional_inputs: respond() only reads the message text, so generation
# controls (system message, max new tokens, temperature, top-p) would have no
# effect on this classifier and would only mislead users.
chatbot = gr.ChatInterface(
    fn=respond,
    type="messages",
    description="Classificador jurídico DistilBERT.",
)

with gr.Blocks() as demo:
    with gr.Sidebar():
        gr.LoginButton()  # optional
    chatbot.render()
# -------------------------------
# Application entry point
# -------------------------------
if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly,
    # not when it is imported as a module.
    demo.launch()