keethu commited on
Commit
52c20e5
·
verified ·
1 Parent(s): acdfced

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -5
app.py CHANGED
@@ -1,12 +1,18 @@
1
  import torch
2
  import torch.nn as nn
3
- from transformers import BertTokenizer, BertModel
4
  import gradio as gr
5
 
6
- model_name = "keethu/bert-emotion-classifier"
 
 
 
7
  tokenizer = BertTokenizer.from_pretrained(model_name)
8
- bert_model = BertModel.from_pretrained(model_name)
9
 
 
 
 
 
10
  class BERTClassifier(nn.Module):
11
  def __init__(self, bert_model, num_labels=5, dropout=0.3):
12
  super(BERTClassifier, self).__init__()
@@ -21,8 +27,19 @@ class BERTClassifier(nn.Module):
21
  logits = self.classifier(pooled_output)
22
  return logits
23
 
24
- model = BERTClassifier(bert_model, num_labels=5, dropout=0.3)
25
- model.load_state_dict(torch.load(f"{model_name}/pytorch_model.bin", map_location='cpu'))
 
 
 
 
 
 
 
 
 
 
 
26
  model.eval()
27
 
28
  emotion_labels = ['anger', 'fear', 'joy', 'sadness', 'surprise']
 
1
  import torch
2
  import torch.nn as nn
3
+ from transformers import BertTokenizer, BertModel, BertForSequenceClassification
4
  import gradio as gr
5
 
6
+ # Your model repo
7
+ model_name = "keethu/bert-emotion-classifier"
8
+
9
+ # Load tokenizer
10
  tokenizer = BertTokenizer.from_pretrained(model_name)
 
11
 
12
+ # Load base BERT model
13
+ base_bert = BertModel.from_pretrained(model_name)
14
+
15
+ # Define your classifier architecture (same as training)
16
  class BERTClassifier(nn.Module):
17
  def __init__(self, bert_model, num_labels=5, dropout=0.3):
18
  super(BERTClassifier, self).__init__()
 
27
  logits = self.classifier(pooled_output)
28
  return logits
29
 
30
+ # Create model instance
31
+ model = BERTClassifier(base_bert, num_labels=5, dropout=0.3)
32
+
33
+ # Load the trained weights - USE from_pretrained properly
34
+ from huggingface_hub import hf_hub_download
35
+ import os
36
+
37
+ # Download the model file
38
+ model_path = hf_hub_download(repo_id=model_name, filename="pytorch_model.bin")
39
+
40
+ # Load state dict
41
+ state_dict = torch.load(model_path, map_location='cpu')
42
+ model.load_state_dict(state_dict)
43
  model.eval()
44
 
45
  emotion_labels = ['anger', 'fear', 'joy', 'sadness', 'surprise']