xjlulu committed on
Commit
fdb5dd9
·
1 Parent(s): dfe9225
Files changed (1) hide show
  1. app.py +46 -77
app.py CHANGED
@@ -1,58 +1,42 @@
1
  import gradio as gr
2
  from typing import Dict, List
3
-
4
  import torch
5
- import torch.nn as nn
6
- import torch.optim as optim
7
- from torch.utils.data import DataLoader
8
-
9
  import json
10
  import pickle
11
  from pathlib import Path
12
-
13
- from dataset import SeqClsDataset
14
  from utils import Vocab
15
  from model import SeqClassifier
 
16
 
17
- import ipdb
18
-
19
  max_len = 128
20
  hidden_size = 256
21
  num_layers = 2
22
  dropout = 0.1
23
  bidirectional = True
24
- lr = 1e-3
25
- batch_size = 64
26
- num_epoch = 5
27
-
28
-
29
- TRAIN = "train"
30
- DEV = "eval"
31
- TEST = "test"
32
- SPLITS = [TRAIN, DEV, TEST]
33
 
34
  device = "cpu"
35
- data_dir = Path("./data/intent/")
36
  ckpt_dir = Path("./ckpt/intent/")
37
  cache_dir = Path("./cache/intent/")
38
- # Before executing, place intent2idx.json, embeddings.pt, vocab.pkl, and utils.py in /content
 
39
  with open(cache_dir / "vocab.pkl", "rb") as f:
40
  vocab: Vocab = pickle.load(f)
 
41
  intent_idx_path = cache_dir / "intent2idx.json"
42
  intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text())
43
- data_paths = {split: data_dir / f"{split}.json" for split in SPLITS}
44
- data = {split: json.loads(path.read_text()) for split, path in data_paths.items()}
45
- datasets: Dict[str, SeqClsDataset] = {
46
- split: SeqClsDataset(split_data, vocab, intent2idx, max_len)
47
- for split, split_data in data.items()
48
- }
49
- #ipdb.set_trace()
50
- test_loader = DataLoader(datasets['test'], batch_size=batch_size, shuffle=False)
51
- embeddings = torch.load(cache_dir / "embeddings.pt")
52
  embeddings.to(device)
53
 
54
- # Load the best model after training
55
- # Initialize a new model with the same architecture
56
  best_model = SeqClassifier(
57
  embeddings=embeddings,
58
  hidden_size=hidden_size,
@@ -65,63 +49,48 @@ best_model = SeqClassifier(
65
  # Define the path to the checkpoint file
66
  ckpt_path = ckpt_dir / "model_checkpoint.pth"
67
 
68
- # Load the model's state_dict and optimizer's state_dict from the checkpoint
69
- checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
70
-
71
  # Load the model's weights
72
- best_model.load_state_dict(checkpoint['model_state_dict']).to(device)
73
-
74
- # Reinitialize the optimizer with the model's parameters and load its state
75
- '''weight_decay = 1e-5
76
- optimizer = optim.Adam(best_model.parameters(), lr=lr, weight_decay=weight_decay)
77
- optimizer.load_state_dict(checkpoint['optimizer_state_dict'])'''
78
-
79
- # Retrieve the epoch number from the checkpoint
80
- epoch = checkpoint['epoch']
81
 
82
- # Set the best model to evaluation mode
83
  best_model.eval()
84
 
85
-
86
- dic_intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text())
87
- dic_idx2label = {idx: intent for intent, idx in dic_intent2idx.items()}
88
-
89
- def Tidx2label(idx: int):
90
- return dic_idx2label[idx]
91
-
92
- with open(cache_dir / "vocab.pkl", "rb") as f:
93
- vocab: Vocab = pickle.load(f)
94
-
95
- # 把句子做成embeddings的索引
96
  def collate_fn(texts: str) -> torch.tensor:
97
- # 提取所有樣本的文本數據和標籤數據
98
  texts = texts.split()
99
-
100
- # 使用 vocab 將文本數據轉換為整數索引序列,並指定最大長度
101
- encoded_texts = vocab.encode_batch([[text for text in texts]], to_len=max_len)
102
-
103
- # 將整數索引序列轉換為 PyTorch 張量
104
  encoded_text = torch.tensor(encoded_texts)
105
  return encoded_text
106
 
107
-
108
  def classify(text):
109
  encoded_text = collate_fn(text).to(device)
110
- output = best_model(encoded_text[0])
111
  Predicted_class = torch.argmax(output).item()
112
- prediction = Tidx2label(Predicted_class)
113
- return prediction
114
 
115
- demo = gr.Interface(
116
- fn=classify,
117
- inputs=gr.Textbox(placeholder="請輸入一段文字..."),
118
- outputs="label",
119
- interpretation="default",
120
- examples=[
121
- ["Take me to church"],
122
- ["tell me what to call you"],
123
- ["could you be a person"]
124
- ]
125
- )
126
 
127
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from typing import Dict, List
 
3
  import torch
4
+ torch.backends.cudnn.enabled = False
 
 
 
5
  import json
6
  import pickle
7
  from pathlib import Path
 
 
8
  from utils import Vocab
9
  from model import SeqClassifier
10
+ from seafoam import Seafoam
11
 
12
+ # Set model parameters
 
13
  max_len = 128
14
  hidden_size = 256
15
  num_layers = 2
16
  dropout = 0.1
17
  bidirectional = True
 
 
 
 
 
 
 
 
 
18
 
19
  device = "cpu"
 
20
  ckpt_dir = Path("./ckpt/intent/")
21
  cache_dir = Path("./cache/intent/")
22
+
23
+ # Load vocabulary and intent index mapping
24
  with open(cache_dir / "vocab.pkl", "rb") as f:
25
  vocab: Vocab = pickle.load(f)
26
+
27
  intent_idx_path = cache_dir / "intent2idx.json"
28
  intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text())
29
+ __idx2label = {idx: intent for intent, idx in intent2idx.items()}
30
+
31
def idx2label(idx: int):
    """Return the intent label stored under class index ``idx``.

    The lookup goes through the module-level ``__idx2label`` mapping, which
    is the inverse of ``intent2idx``. An index that is not in the mapping
    raises ``KeyError``.
    """
    label = __idx2label[idx]
    return label
33
+
34
+ # Set embedding layer size
35
# Shape of the embedding matrix handed to the model constructor.
embeddings_size = (6491, 300)
# Tensor.to() returns the moved tensor instead of modifying it in place, so a
# bare `embeddings.to(device)` has no effect. Allocate on the target device
# directly so the tensor really lives there.
# NOTE(review): the values are uninitialised; presumably they are overwritten
# when the checkpoint's state_dict is loaded -- confirm the embedding weights
# are part of that state_dict.
embeddings = torch.empty(embeddings_size, device=device)
38
 
39
+ # Load the best model
 
40
  best_model = SeqClassifier(
41
  embeddings=embeddings,
42
  hidden_size=hidden_size,
 
49
  # Define the path to the checkpoint file
50
  ckpt_path = ckpt_dir / "model_checkpoint.pth"
51
 
 
 
 
52
  # Load the model's weights
53
+ checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
54
+ best_model.load_state_dict(checkpoint['model_state_dict'])
 
 
 
 
 
 
 
55
 
56
+ # Set the model to evaluation mode
57
  best_model.eval()
58
 
59
+ # Processing function to convert text to embedding indices
 
 
 
 
 
 
 
 
 
 
60
def collate_fn(texts: str) -> torch.Tensor:
    """Encode one sentence as a tensor of vocabulary indices.

    Args:
        texts: Raw input sentence; it is split on whitespace into tokens.

    Returns:
        A tensor built from the first (and only) row that
        ``vocab.encode_batch`` returns when called with ``to_len=max_len``.
    """
    tokens = texts.split()
    # encode_batch operates on a batch of token lists, so wrap the single
    # sentence in a one-element batch and unwrap the single result.
    encoded_tokens = vocab.encode_batch([tokens], to_len=max_len)[0]
    return torch.tensor(encoded_tokens)
65
 
66
+ # Classification function
67
def classify(text):
    """Predict the intent of ``text`` and return it as a display string.

    Args:
        text: Raw sentence entered by the user.

    Returns:
        The string ``"Category:"`` followed by the predicted intent label.
    """
    encoded_text = collate_fn(text).to(device)
    # Inference only: disable autograd so the forward pass does not build a
    # computation graph that is never used.
    with torch.no_grad():
        output = best_model(encoded_text)
    predicted_class = torch.argmax(output).item()
    return "Category:" + idx2label(predicted_class)
73
 
74
+ # Use the Seafoam theme
75
+ seafoam = Seafoam()
 
 
 
 
 
 
 
 
 
76
 
77
+ # Create a Gradio interface
78
+ demo = gr.Interface(
79
+ fn=classify,
80
+ inputs=gr.Textbox(placeholder="Please enter a text..."),
81
+ outputs="label",
82
+ interpretation="none",
83
+ live=False,
84
+ enable_queue=True,
85
+ examples=[
86
+ ["please set an alarm for mid day"],
87
+ ["tell lydia and laura where i am located"],
88
+ ["what's the deal with my health care"]
89
+ ],
90
+ title="Text Intent Classification Demo",
91
+ description="This demo uses a model to classify text into different intents or categories. Enter a text and see the classification result.",
92
+ theme=seafoam
93
+ )
94
+
95
+ # Launch the Gradio interface
96
+ demo.launch(share=True)