asdfasdfdsafdsa committed on
Commit
2a2cec1
·
verified ·
1 Parent(s): 6f74e93

Fix Lang classes, CUDA compatibility, and config imports

Browse files
app.py CHANGED
@@ -13,43 +13,11 @@ from core.network import Network, MLMTransformerPretrain
13
  from model.backbone import get_visual_backbone
14
  from model.encoder import get_encoder
15
  from model.decoder import get_decoder
16
- from datasets.preprossing import SN
17
  from datasets.utils import get_combined_text, get_var_arg, get_text_index
18
  from datasets.operators import normalize_exp
19
  import datasets.diagram_aug as T_diagram
20
 
21
- # Language classes for vocabulary management
22
- class Lang:
23
- def __init__(self):
24
- self.word2index = {}
25
- self.word2count = {}
26
- self.index2word = {0: "PAD", 1: "SOS", 2: "EOS", 3: "UNK"}
27
- self.n_words = 4
28
- self.class_tag = ['PAD', 'QUE', 'VAR', 'NUM', 'SEP']
29
- self.sect_tag = ['PAD', 'TEXT', 'STRU', 'SEM']
30
-
31
- def add_sentence(self, sentence):
32
- for word in sentence.split(' '):
33
- self.add_word(word)
34
-
35
- def add_word(self, word):
36
- if word not in self.word2index:
37
- self.word2index[word] = self.n_words
38
- self.word2count[word] = 1
39
- self.index2word[self.n_words] = word
40
- self.n_words += 1
41
- else:
42
- self.word2count[word] += 1
43
-
44
- def indexes_from_sentence(self, sentence, var_values=None, arg_values=None):
45
- indexes = []
46
- for word in sentence.split(' '):
47
- if word in self.word2index:
48
- indexes.append(self.word2index[word])
49
- else:
50
- indexes.append(3) # UNK
51
- return indexes
52
-
53
  # Configuration class
54
  class Config:
55
  def __init__(self):
@@ -89,18 +57,9 @@ class Config:
89
  def load_model():
90
  cfg = Config()
91
 
92
- # Load vocabularies
93
- src_lang = Lang()
94
- tgt_lang = Lang()
95
-
96
- # Load vocab files
97
- with open('./vocab/vocab_src.txt', 'r') as f:
98
- for line in f:
99
- src_lang.add_word(line.strip())
100
-
101
- with open('./vocab/vocab_tgt.txt', 'r') as f:
102
- for line in f:
103
- tgt_lang.add_word(line.strip())
104
 
105
  # Create model
106
  model = Network(cfg, src_lang, tgt_lang)
 
13
  from model.backbone import get_visual_backbone
14
  from model.encoder import get_encoder
15
  from model.decoder import get_decoder
16
+ from datasets.preprossing import SN, SrcLang, TgtLang
17
  from datasets.utils import get_combined_text, get_var_arg, get_text_index
18
  from datasets.operators import normalize_exp
19
  import datasets.diagram_aug as T_diagram
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # Configuration class
22
  class Config:
23
  def __init__(self):
 
57
  def load_model():
58
  cfg = Config()
59
 
60
+ # Load vocabularies using proper Lang classes
61
+ src_lang = SrcLang('./vocab/vocab_src.txt')
62
+ tgt_lang = TgtLang('./vocab/vocab_tgt.txt')
 
 
 
 
 
 
 
 
 
63
 
64
  # Create model
65
  model = Network(cfg, src_lang, tgt_lang)
loss/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
  from .loss import *
2
- from config import criterion_list
 
3
 
4
 
5
  def get_criterion(args):
 
1
  from .loss import *
2
+
3
+ criterion_list = ["CrossEntropy", "FocalLoss", "MaskedCrossEntropy"]
4
 
5
 
6
  def get_criterion(args):
model/classifier/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
  from .classifier_ops import *
2
- from config import classifier_list
 
3
 
4
 
5
  def get_classifier(args):
 
1
  from .classifier_ops import *
2
+
3
+ classifier_list = ["FCNorm", "CosNorm", "DotProduct", "DistFC"]
4
 
5
 
6
  def get_classifier(args):
model/decoder/rnn_decoder.py CHANGED
@@ -23,7 +23,8 @@ class DecoderRNN(nn.Module):
23
  self.attn = Attn(cfg.encoder_hidden_size, cfg.decoder_hidden_size)
24
  self.score = Score(cfg.encoder_hidden_size+cfg.decoder_hidden_size, cfg.decoder_embedding_size)
25
  # predefined constant
26
- self.no_var_id = torch.arange(self.var_start).unsqueeze(0).cuda()
 
27
  self.cfg = cfg
28
 
29
  def get_var_encoder_outputs(self, encoder_outputs, var_pos):
@@ -127,15 +128,15 @@ class DecoderRNN(nn.Module):
127
  for i in range(self.cfg.max_output_len):
128
  # initial varible
129
  if i==0:
130
- input_token = torch.LongTensor([[self.sos_id]]*rem_size).cuda() # rem_size x 1
131
  rnn_hidden = problem_output[:, sample_id:sample_id+1].repeat(1, rem_size, 1) # layer_num x rem_size x H
132
- current_score = torch.FloatTensor([[0.0]]*rem_size).cuda() # rem_size x 1
133
  current_exp_list = [[]]*rem_size
134
  else:
135
- input_token = torch.LongTensor(token_list).unsqueeze(1).cuda()
136
  rnn_hidden = rnn_hidden[:, cand_list]
137
  rem_size = len(exp_list)
138
- current_score = torch.FloatTensor(score_list[:rem_size]).unsqueeze(1).cuda()
139
  current_exp_list = exp_list
140
 
141
  # input embedding
 
23
  self.attn = Attn(cfg.encoder_hidden_size, cfg.decoder_hidden_size)
24
  self.score = Score(cfg.encoder_hidden_size+cfg.decoder_hidden_size, cfg.decoder_embedding_size)
25
  # predefined constant
26
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ self.no_var_id = torch.arange(self.var_start).unsqueeze(0).to(self.device)
28
  self.cfg = cfg
29
 
30
  def get_var_encoder_outputs(self, encoder_outputs, var_pos):
 
128
  for i in range(self.cfg.max_output_len):
129
  # initial varible
130
  if i==0:
131
+ input_token = torch.LongTensor([[self.sos_id]]*rem_size).to(self.device) # rem_size x 1
132
  rnn_hidden = problem_output[:, sample_id:sample_id+1].repeat(1, rem_size, 1) # layer_num x rem_size x H
133
+ current_score = torch.FloatTensor([[0.0]]*rem_size).to(self.device) # rem_size x 1
134
  current_exp_list = [[]]*rem_size
135
  else:
136
+ input_token = torch.LongTensor(token_list).unsqueeze(1).to(self.device)
137
  rnn_hidden = rnn_hidden[:, cand_list]
138
  rem_size = len(exp_list)
139
+ current_score = torch.FloatTensor(score_list[:rem_size]).unsqueeze(1).to(self.device)
140
  current_exp_list = exp_list
141
 
142
  # input embedding
utils/utils.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
  import torch
3
  from utils.lr_scheduler import WarmupMultiStepLR
4
- from config import *
5
  import datetime
6
  import torch.distributed as dist
7
  from datasets.operators import result_compute, normalize_exp
 
1
  import os
2
  import torch
3
  from utils.lr_scheduler import WarmupMultiStepLR
 
4
  import datetime
5
  import torch.distributed as dist
6
  from datasets.operators import result_compute, normalize_exp