Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,7 +14,7 @@ nltk.download(['punkt', 'punkt_tab'], quiet=True)
|
|
| 14 |
|
| 15 |
DEVICE = torch.device("cpu")
|
| 16 |
|
| 17 |
-
|
| 18 |
MODEL_FILE_WITH_ATTN = "ubuntu_chatbot_with_attn.pt" # trained model with attn
|
| 19 |
MODEL_FILE_NO_ATTN = "ubuntu_chatbot_no_attn.pt" # trained model without attn
|
| 20 |
|
|
@@ -53,7 +53,7 @@ class Vocab:
|
|
| 53 |
|
| 54 |
# ------------- load vocab from cache -------------
|
| 55 |
print("Loading vocab...")
|
| 56 |
-
data = torch.load(
|
| 57 |
vocab = data["vocab"]
|
| 58 |
|
| 59 |
PAD_IDX = vocab.word2idx["<PAD>"]
|
|
@@ -115,18 +115,12 @@ class Decoder_with_attn(nn.Module):
|
|
| 115 |
self.norm = nn.LayerNorm(512)
|
| 116 |
|
| 117 |
def forward(self, inp, hidden, enc_out):
|
| 118 |
-
"""
|
| 119 |
-
inp: [B, 1] token IDs
|
| 120 |
-
hidden: [2, B, 512] encoder hidden (num_layers, batch, hidden)
|
| 121 |
-
enc_out:[B, T, 512]
|
| 122 |
-
"""
|
| 123 |
e = self.dropout(self.emb(inp))
|
| 124 |
-
|
| 125 |
# attention over encoder outputs
|
| 126 |
energy = self.attn(enc_out)
|
| 127 |
# use top layer hidden state for attention
|
| 128 |
-
attn_scores = torch.bmm(hidden[-1].unsqueeze(1), energy.transpose(1, 2))
|
| 129 |
-
attn_weights = F.softmax(attn_scores.squeeze(1), dim=-1).unsqueeze(1)
|
| 130 |
ctx = torch.bmm(attn_weights, enc_out)
|
| 131 |
|
| 132 |
x = torch.cat((e, ctx), dim=-1)
|
|
@@ -160,10 +154,10 @@ class Model_with_attn(nn.Module):
|
|
| 160 |
|
| 161 |
def forward(self, src, tgt, tf=0.5):
|
| 162 |
enc_out, h = self.encoder(src)
|
| 163 |
-
dec_in = tgt[:, 0]
|
| 164 |
outs = []
|
| 165 |
for t in range(1, tgt.size(1)):
|
| 166 |
-
dec_in = dec_in.unsqueeze(1)
|
| 167 |
out, h = self.decoder(dec_in, h, enc_out)
|
| 168 |
outs.append(out)
|
| 169 |
use_tf = random.random() < tf
|
|
@@ -178,10 +172,10 @@ class Model_no_attn(nn.Module):
|
|
| 178 |
|
| 179 |
def forward(self, src, tgt, tf=0.5):
|
| 180 |
enc_out, h = self.encoder(src)
|
| 181 |
-
dec_in = tgt[:, 0]
|
| 182 |
outs = []
|
| 183 |
for t in range(1, tgt.size(1)):
|
| 184 |
-
dec_in = dec_in.unsqueeze(1)
|
| 185 |
out, h = self.decoder(dec_in, h)
|
| 186 |
outs.append(out)
|
| 187 |
use_tf = random.random() < tf
|
|
@@ -266,7 +260,6 @@ def beam_generate_v2(model, src_tensor, beam=5, max_len=50, alpha=0.7):
|
|
| 266 |
|
| 267 |
# ------------- wrapper to go from user text → reply -------------
|
| 268 |
def generate_reply_attn(user_text: str) -> str:
|
| 269 |
-
# replicate notebook logic: reverse the input sentence
|
| 270 |
user_text_rev = reverse(user_text)
|
| 271 |
tokens = tokenize(user_text_rev)
|
| 272 |
ids = [SOS_IDX] + [vocab.word2idx.get(w, UNK_IDX) for w in tokens] + [EOS_IDX]
|
|
@@ -277,7 +270,6 @@ def generate_reply_attn(user_text: str) -> str:
|
|
| 277 |
return reply
|
| 278 |
|
| 279 |
def generate_reply_no_attn(user_text: str) -> str:
|
| 280 |
-
# replicate notebook logic: reverse the input sentence
|
| 281 |
user_text_rev = reverse(user_text)
|
| 282 |
tokens = tokenize(user_text_rev)
|
| 283 |
ids = [SOS_IDX] + [vocab.word2idx.get(w, UNK_IDX) for w in tokens] + [EOS_IDX]
|
|
|
|
| 14 |
|
| 15 |
DEVICE = torch.device("cpu")
|
| 16 |
|
| 17 |
+
VOCAB_FILE = "ubuntu_vocab_only.pt" # To get the Vocab from cache
|
| 18 |
MODEL_FILE_WITH_ATTN = "ubuntu_chatbot_with_attn.pt" # trained model with attn
|
| 19 |
MODEL_FILE_NO_ATTN = "ubuntu_chatbot_no_attn.pt" # trained model without attn
|
| 20 |
|
|
|
|
| 53 |
|
| 54 |
# ------------- load vocab from cache -------------
|
| 55 |
print("Loading vocab...")
|
| 56 |
+
data = torch.load(VOCAB_FILE, map_location="cpu", weights_only=False)
|
| 57 |
vocab = data["vocab"]
|
| 58 |
|
| 59 |
PAD_IDX = vocab.word2idx["<PAD>"]
|
|
|
|
| 115 |
self.norm = nn.LayerNorm(512)
|
| 116 |
|
| 117 |
def forward(self, inp, hidden, enc_out):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
e = self.dropout(self.emb(inp))
|
|
|
|
| 119 |
# attention over encoder outputs
|
| 120 |
energy = self.attn(enc_out)
|
| 121 |
# use top layer hidden state for attention
|
| 122 |
+
attn_scores = torch.bmm(hidden[-1].unsqueeze(1), energy.transpose(1, 2))
|
| 123 |
+
attn_weights = F.softmax(attn_scores.squeeze(1), dim=-1).unsqueeze(1)
|
| 124 |
ctx = torch.bmm(attn_weights, enc_out)
|
| 125 |
|
| 126 |
x = torch.cat((e, ctx), dim=-1)
|
|
|
|
| 154 |
|
| 155 |
def forward(self, src, tgt, tf=0.5):
|
| 156 |
enc_out, h = self.encoder(src)
|
| 157 |
+
dec_in = tgt[:, 0]
|
| 158 |
outs = []
|
| 159 |
for t in range(1, tgt.size(1)):
|
| 160 |
+
dec_in = dec_in.unsqueeze(1)
|
| 161 |
out, h = self.decoder(dec_in, h, enc_out)
|
| 162 |
outs.append(out)
|
| 163 |
use_tf = random.random() < tf
|
|
|
|
| 172 |
|
| 173 |
def forward(self, src, tgt, tf=0.5):
|
| 174 |
enc_out, h = self.encoder(src)
|
| 175 |
+
dec_in = tgt[:, 0]
|
| 176 |
outs = []
|
| 177 |
for t in range(1, tgt.size(1)):
|
| 178 |
+
dec_in = dec_in.unsqueeze(1)
|
| 179 |
out, h = self.decoder(dec_in, h)
|
| 180 |
outs.append(out)
|
| 181 |
use_tf = random.random() < tf
|
|
|
|
| 260 |
|
| 261 |
# ------------- wrapper to go from user text → reply -------------
|
| 262 |
def generate_reply_attn(user_text: str) -> str:
|
|
|
|
| 263 |
user_text_rev = reverse(user_text)
|
| 264 |
tokens = tokenize(user_text_rev)
|
| 265 |
ids = [SOS_IDX] + [vocab.word2idx.get(w, UNK_IDX) for w in tokens] + [EOS_IDX]
|
|
|
|
| 270 |
return reply
|
| 271 |
|
| 272 |
def generate_reply_no_attn(user_text: str) -> str:
|
|
|
|
| 273 |
user_text_rev = reverse(user_text)
|
| 274 |
tokens = tokenize(user_text_rev)
|
| 275 |
ids = [SOS_IDX] + [vocab.word2idx.get(w, UNK_IDX) for w in tokens] + [EOS_IDX]
|