Jet-12138 commited on
Commit
1d71490
·
verified ·
1 Parent(s): cbce16f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -6
app.py CHANGED
@@ -2,10 +2,11 @@ import torch
2
  import torch.nn.functional as F
3
  from transformers import BertTokenizer
4
  import gradio as gr
 
5
 
6
- from model import CommentMTLModel # 导入你自定义的模型
7
 
8
- # Set device, including MPS support for Mac
9
  if torch.backends.mps.is_available():
10
  device = torch.device("mps")
11
  elif torch.cuda.is_available():
@@ -16,8 +17,17 @@ else:
16
  # Load tokenizer
17
  tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
18
 
19
- # Load model
20
- model = CommentMTLModel(config_path="config.json")
 
 
 
 
 
 
 
 
 
21
  model.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
22
  model.to(device)
23
  model.eval()
@@ -32,7 +42,10 @@ def analyse_comment(comment):
32
  inputs = {k: v.to(device) for k, v in inputs.items()}
33
 
34
  with torch.no_grad():
35
- sentiment_logits, toxicity_logits = model(**inputs)
 
 
 
36
 
37
  # Process sentiment
38
  sentiment_probs = F.softmax(sentiment_logits, dim=1)
@@ -58,7 +71,7 @@ iface = gr.Interface(
58
  gr.Label(num_top_classes=1, label="Predicted Toxicity")
59
  ],
60
  title="Comment Sentiment and Toxicity Classifier",
61
- description="This tool classifies the sentiment and the most probable type of toxicity in a given comment. It utilises a custom fine-tuned BERT model. Developed for academic demonstration purposes in Australia."
62
  )
63
 
64
  iface.launch()
 
2
  import torch.nn.functional as F
3
  from transformers import BertTokenizer
4
  import gradio as gr
5
+ import json
6
 
7
+ from model import CommentMTLModel # 用你的类名
8
 
9
+ # Set device, including MPS
10
  if torch.backends.mps.is_available():
11
  device = torch.device("mps")
12
  elif torch.cuda.is_available():
 
17
  # Load tokenizer
18
  tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
19
 
20
+ # Load config values manually
21
+ with open("config.json", "r") as f:
22
+ config_data = json.load(f)
23
+
24
+ # Create model
25
+ model = CommentMTLModel(
26
+ model_name="bert-base-uncased",
27
+ num_sentiment_labels=config_data["num_sentiment_labels"],
28
+ num_toxicity_labels=config_data["num_toxicity_labels"],
29
+ dropout_prob=config_data.get("dropout_prob", 0.1)
30
+ )
31
  model.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
32
  model.to(device)
33
  model.eval()
 
42
  inputs = {k: v.to(device) for k, v in inputs.items()}
43
 
44
  with torch.no_grad():
45
+ outputs = model(**inputs)
46
+
47
+ sentiment_logits = outputs["sentiment_logits"]
48
+ toxicity_logits = outputs["toxicity_logits"]
49
 
50
  # Process sentiment
51
  sentiment_probs = F.softmax(sentiment_logits, dim=1)
 
71
  gr.Label(num_top_classes=1, label="Predicted Toxicity")
72
  ],
73
  title="Comment Sentiment and Toxicity Classifier",
74
+ description="This tool classifies the sentiment and the most probable type of toxicity in a given comment. It utilises a custom multi-task learning BERT model. Developed for academic demonstration purposes in Australia."
75
  )
76
 
77
  iface.launch()