amaresh8053 committed on
Commit
9391674
·
verified ·
1 Parent(s): 9aba057

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -16
app.py CHANGED
@@ -14,7 +14,7 @@ nltk.download(['punkt', 'punkt_tab'], quiet=True)
14
 
15
  DEVICE = torch.device("cpu")
16
 
17
- CACHE_FILE = "ubuntu_data_cache.pt" # To get the Vocab from cache
18
  MODEL_FILE_WITH_ATTN = "ubuntu_chatbot_with_attn.pt" # trained model with attn
19
  MODEL_FILE_NO_ATTN = "ubuntu_chatbot_no_attn.pt" # trained model without attn
20
 
@@ -53,7 +53,7 @@ class Vocab:
53
 
54
  # ------------- load vocab from cache -------------
55
  print("Loading vocab...")
56
- data = torch.load("ubuntu_vocab_only.pt", map_location="cpu", weights_only=False)
57
  vocab = data["vocab"]
58
 
59
  PAD_IDX = vocab.word2idx["<PAD>"]
@@ -115,18 +115,12 @@ class Decoder_with_attn(nn.Module):
115
  self.norm = nn.LayerNorm(512)
116
 
117
  def forward(self, inp, hidden, enc_out):
118
- """
119
- inp: [B, 1] token IDs
120
- hidden: [2, B, 512] encoder hidden (num_layers, batch, hidden)
121
- enc_out:[B, T, 512]
122
- """
123
  e = self.dropout(self.emb(inp))
124
-
125
  # attention over encoder outputs
126
  energy = self.attn(enc_out)
127
  # use top layer hidden state for attention
128
- attn_scores = torch.bmm(hidden[-1].unsqueeze(1), energy.transpose(1, 2)) # [B,1,T]
129
- attn_weights = F.softmax(attn_scores.squeeze(1), dim=-1).unsqueeze(1) # [B,1,T]
130
  ctx = torch.bmm(attn_weights, enc_out)
131
 
132
  x = torch.cat((e, ctx), dim=-1)
@@ -160,10 +154,10 @@ class Model_with_attn(nn.Module):
160
 
161
  def forward(self, src, tgt, tf=0.5):
162
  enc_out, h = self.encoder(src)
163
- dec_in = tgt[:, 0] # <SOS>
164
  outs = []
165
  for t in range(1, tgt.size(1)):
166
- dec_in = dec_in.unsqueeze(1) # [B,1]
167
  out, h = self.decoder(dec_in, h, enc_out)
168
  outs.append(out)
169
  use_tf = random.random() < tf
@@ -178,10 +172,10 @@ class Model_no_attn(nn.Module):
178
 
179
  def forward(self, src, tgt, tf=0.5):
180
  enc_out, h = self.encoder(src)
181
- dec_in = tgt[:, 0] # <SOS>
182
  outs = []
183
  for t in range(1, tgt.size(1)):
184
- dec_in = dec_in.unsqueeze(1) # [B,1]
185
  out, h = self.decoder(dec_in, h)
186
  outs.append(out)
187
  use_tf = random.random() < tf
@@ -266,7 +260,6 @@ def beam_generate_v2(model, src_tensor, beam=5, max_len=50, alpha=0.7):
266
 
267
  # ------------- wrapper to go from user text → reply -------------
268
  def generate_reply_attn(user_text: str) -> str:
269
- # replicate notebook logic: reverse the input sentence
270
  user_text_rev = reverse(user_text)
271
  tokens = tokenize(user_text_rev)
272
  ids = [SOS_IDX] + [vocab.word2idx.get(w, UNK_IDX) for w in tokens] + [EOS_IDX]
@@ -277,7 +270,6 @@ def generate_reply_attn(user_text: str) -> str:
277
  return reply
278
 
279
  def generate_reply_no_attn(user_text: str) -> str:
280
- # replicate notebook logic: reverse the input sentence
281
  user_text_rev = reverse(user_text)
282
  tokens = tokenize(user_text_rev)
283
  ids = [SOS_IDX] + [vocab.word2idx.get(w, UNK_IDX) for w in tokens] + [EOS_IDX]
 
14
 
15
  DEVICE = torch.device("cpu")
16
 
17
+ VOCAB_FILE = "ubuntu_vocab_only.pt" # To get the Vocab from cache
18
  MODEL_FILE_WITH_ATTN = "ubuntu_chatbot_with_attn.pt" # trained model with attn
19
  MODEL_FILE_NO_ATTN = "ubuntu_chatbot_no_attn.pt" # trained model without attn
20
 
 
53
 
54
  # ------------- load vocab from cache -------------
55
  print("Loading vocab...")
56
+ data = torch.load(VOCAB_FILE, map_location="cpu", weights_only=False)
57
  vocab = data["vocab"]
58
 
59
  PAD_IDX = vocab.word2idx["<PAD>"]
 
115
  self.norm = nn.LayerNorm(512)
116
 
117
  def forward(self, inp, hidden, enc_out):
 
 
 
 
 
118
  e = self.dropout(self.emb(inp))
 
119
  # attention over encoder outputs
120
  energy = self.attn(enc_out)
121
  # use top layer hidden state for attention
122
+ attn_scores = torch.bmm(hidden[-1].unsqueeze(1), energy.transpose(1, 2))
123
+ attn_weights = F.softmax(attn_scores.squeeze(1), dim=-1).unsqueeze(1)
124
  ctx = torch.bmm(attn_weights, enc_out)
125
 
126
  x = torch.cat((e, ctx), dim=-1)
 
154
 
155
  def forward(self, src, tgt, tf=0.5):
156
  enc_out, h = self.encoder(src)
157
+ dec_in = tgt[:, 0]
158
  outs = []
159
  for t in range(1, tgt.size(1)):
160
+ dec_in = dec_in.unsqueeze(1)
161
  out, h = self.decoder(dec_in, h, enc_out)
162
  outs.append(out)
163
  use_tf = random.random() < tf
 
172
 
173
  def forward(self, src, tgt, tf=0.5):
174
  enc_out, h = self.encoder(src)
175
+ dec_in = tgt[:, 0]
176
  outs = []
177
  for t in range(1, tgt.size(1)):
178
+ dec_in = dec_in.unsqueeze(1)
179
  out, h = self.decoder(dec_in, h)
180
  outs.append(out)
181
  use_tf = random.random() < tf
 
260
 
261
  # ------------- wrapper to go from user text → reply -------------
262
  def generate_reply_attn(user_text: str) -> str:
 
263
  user_text_rev = reverse(user_text)
264
  tokens = tokenize(user_text_rev)
265
  ids = [SOS_IDX] + [vocab.word2idx.get(w, UNK_IDX) for w in tokens] + [EOS_IDX]
 
270
  return reply
271
 
272
  def generate_reply_no_attn(user_text: str) -> str:
 
273
  user_text_rev = reverse(user_text)
274
  tokens = tokenize(user_text_rev)
275
  ids = [SOS_IDX] + [vocab.word2idx.get(w, UNK_IDX) for w in tokens] + [EOS_IDX]