jpmendes commited on
Commit
2ff098a
·
verified ·
1 Parent(s): fe7c5f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -12
app.py CHANGED
@@ -1,40 +1,48 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import DistilBertTokenizer, AutoModelForSequenceClassification
 
4
 
5
  # -------------------------------
6
  # Configuração do modelo
7
  # -------------------------------
8
  MODEL_NAME = "hermanshid/distilbert-id-law"
9
 
10
- # Como o modelo não tem tokenizer, usamos um tokenizer de DistilBERT base compatível
11
  tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
12
 
13
- # Carrega o modelo customizado
14
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
15
-
16
- # Definir dispositivo (CPU suficiente para Space gratuito)
17
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
  model.to(device)
19
  model.eval()
20
 
 
 
 
 
 
 
21
  # -------------------------------
22
  # Função de inferência
23
  # -------------------------------
24
  def respond(message, history, system_message=None, max_tokens=None, temperature=None, top_p=None, hf_token=None):
25
- """
26
- Recebe texto do usuário, retorna a classe prevista pelo DistilBERT.
27
- Os parâmetros history, system_message, max_tokens etc são mantidos
28
- apenas para compatibilidade com ChatInterface.
29
- """
30
  inputs = tokenizer(message, return_tensors="pt", truncation=True, padding=True)
31
  inputs = {k: v.to(device) for k, v in inputs.items()}
32
 
33
  with torch.no_grad():
34
  logits = model(**inputs).logits
 
 
 
35
 
36
- pred_class = torch.argmax(logits, dim=-1).item()
37
- return f"Predicted class: {pred_class}"
 
 
 
 
 
38
 
39
  # -------------------------------
40
  # Interface estilo Chatbot
@@ -59,4 +67,4 @@ with gr.Blocks() as demo:
59
  # Inicializa a aplicação
60
  # -------------------------------
61
  if __name__ == "__main__":
62
- demo.launch()
 
1
  import gradio as gr
2
  import torch
3
  from transformers import DistilBertTokenizer, AutoModelForSequenceClassification
4
+ import torch.nn.functional as F
5
 
6
  # -------------------------------
7
  # Configuração do modelo
8
  # -------------------------------
9
  MODEL_NAME = "hermanshid/distilbert-id-law"
10
 
11
+ # Tokenizer compatível do DistilBERT base
12
  tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
13
 
14
+ # Modelo customizado
15
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
 
 
16
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  model.to(device)
18
  model.eval()
19
 
20
+ # -------------------------------
21
+ # Labels dummy para visualização
22
+ # -------------------------------
23
+ num_labels = model.config.num_labels
24
+ labels = [f"Class {i}" for i in range(num_labels)]
25
+
26
  # -------------------------------
27
  # Função de inferência
28
  # -------------------------------
29
  def respond(message, history, system_message=None, max_tokens=None, temperature=None, top_p=None, hf_token=None):
 
 
 
 
 
30
  inputs = tokenizer(message, return_tensors="pt", truncation=True, padding=True)
31
  inputs = {k: v.to(device) for k, v in inputs.items()}
32
 
33
  with torch.no_grad():
34
  logits = model(**inputs).logits
35
+ probs = F.softmax(logits, dim=-1).squeeze().tolist() # probabilidades
36
+ pred_index = torch.argmax(logits, dim=-1).item()
37
+ pred_label = labels[pred_index]
38
 
39
+ # Monta mensagem de saída detalhada
40
+ output = f"Predicted label: {pred_label}\n\n"
41
+ output += "Probabilities:\n"
42
+ for i, p in enumerate(probs):
43
+ output += f"{labels[i]}: {p:.4f}\n"
44
+ output += f"\nLogits: {logits.squeeze().tolist()}"
45
+ return output
46
 
47
  # -------------------------------
48
  # Interface estilo Chatbot
 
67
  # Inicializa a aplicação
68
  # -------------------------------
69
  if __name__ == "__main__":
70
+ demo.launch()