Spaces:
Sleeping
Sleeping
Commit ·
1d0f992
1
Parent(s): 882705a
Update app.py
Browse files
app.py
CHANGED
|
@@ -15,7 +15,7 @@ nltk.download(['punkt', 'punkt_tab'], quiet=True)
|
|
| 15 |
DEVICE = torch.device("cpu")
|
| 16 |
|
| 17 |
CACHE_FILE = "ubuntu_data_cache.pt" # To get the Vocab from cache
|
| 18 |
-
|
| 19 |
|
| 20 |
|
| 21 |
# ------------- tokenization + helpers -------------
|
|
@@ -110,7 +110,7 @@ class Encoder(nn.Module):
|
|
| 110 |
return out, h
|
| 111 |
|
| 112 |
|
| 113 |
-
class
|
| 114 |
def __init__(self):
|
| 115 |
super().__init__()
|
| 116 |
self.emb = nn.Embedding(len(vocab), 256, padding_idx=PAD_IDX)
|
|
@@ -147,12 +147,28 @@ class Decoder(nn.Module):
|
|
| 147 |
logits = self.out(out)
|
| 148 |
return logits, hidden
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
def __init__(self):
|
| 153 |
super().__init__()
|
| 154 |
self.encoder = Encoder()
|
| 155 |
-
self.decoder =
|
| 156 |
|
| 157 |
def forward(self, src, tgt, tf=0.5):
|
| 158 |
enc_out, h = self.encoder(src)
|
|
@@ -168,21 +184,21 @@ class Model(nn.Module):
|
|
| 168 |
|
| 169 |
|
| 170 |
# ------------- load trained model -------------
|
| 171 |
-
if not os.path.exists(
|
| 172 |
raise FileNotFoundError(
|
| 173 |
-
f"{
|
| 174 |
)
|
| 175 |
|
| 176 |
-
|
| 177 |
-
ckpt = torch.load(
|
| 178 |
-
|
| 179 |
-
|
| 180 |
|
| 181 |
print("Model and vocab loaded. Chatbot ready to serve ")
|
| 182 |
|
| 183 |
|
| 184 |
# ------------- beam search (beam_generate_v2 from notebook) -------------
|
| 185 |
-
def beam_generate_v2(src_tensor, beam=5, max_len=50, alpha=0.7):
|
| 186 |
"""
|
| 187 |
src_tensor: [1, T] LongTensor with <SOS> ... <EOS>
|
| 188 |
alpha: length penalty factor
|
|
@@ -258,18 +274,18 @@ def beam_generate_v2(src_tensor, beam=5, max_len=50, alpha=0.7):
|
|
| 258 |
|
| 259 |
|
| 260 |
# ------------- wrapper to go from user text → reply -------------
|
| 261 |
-
def
|
| 262 |
# replicate notebook logic: reverse the input sentence
|
| 263 |
user_text_rev = reverse(user_text)
|
| 264 |
tokens = tokenize(user_text_rev)
|
| 265 |
ids = [SOS_IDX] + [vocab.word2idx.get(w, UNK_IDX) for w in tokens] + [EOS_IDX]
|
| 266 |
src = torch.tensor([ids], dtype=torch.long, device=DEVICE)
|
| 267 |
-
reply = beam_generate_v2(src, beam=5, max_len=50)
|
| 268 |
if not reply.strip():
|
| 269 |
return "I'm a chatbot trained on Ubuntu Linux support conversations, so I may not understand this question."
|
| 270 |
return reply
|
| 271 |
|
| 272 |
-
def
|
| 273 |
"""
|
| 274 |
Inference using the ATTENTION model.
|
| 275 |
Replace body with your encoder/decoder calls (beam or greedy).
|
|
|
|
| 15 |
DEVICE = torch.device("cpu")
|
| 16 |
|
| 17 |
CACHE_FILE = "ubuntu_data_cache.pt" # To get the Vocab from cache
|
| 18 |
+
MODEL_FILE_WITH_ATTN = "ubuntu_chatbot_with_attn.pt" # trained model with attn
|
| 19 |
|
| 20 |
|
| 21 |
# ------------- tokenization + helpers -------------
|
|
|
|
| 110 |
return out, h
|
| 111 |
|
| 112 |
|
| 113 |
+
class Decoder_with_attn(nn.Module):
|
| 114 |
def __init__(self):
|
| 115 |
super().__init__()
|
| 116 |
self.emb = nn.Embedding(len(vocab), 256, padding_idx=PAD_IDX)
|
|
|
|
| 147 |
logits = self.out(out)
|
| 148 |
return logits, hidden
|
| 149 |
|
| 150 |
+
class Decoder_no_attn(nn.Module):
|
| 151 |
+
def __init__(self):
|
| 152 |
+
super().__init__()
|
| 153 |
+
self.emb = nn.Embedding(len(vocab), 256, padding_idx=0)
|
| 154 |
+
self.dropout = nn.Dropout(0.3) # added dropout layer
|
| 155 |
+
self.gru = nn.GRU(256, 512, num_layers=2, batch_first=True)
|
| 156 |
+
|
| 157 |
+
self.out = nn.Linear(512, len(vocab))
|
| 158 |
+
self.norm = nn.LayerNorm(512)
|
| 159 |
+
|
| 160 |
+
def forward(self, inp, hidden):
|
| 161 |
+
e = self.dropout(self.emb(inp))
|
| 162 |
+
|
| 163 |
+
out, hidden = self.gru(e, hidden)
|
| 164 |
+
out = self.norm(out.squeeze(1))
|
| 165 |
+
return self.out(out), hidden
|
| 166 |
+
|
| 167 |
+
class Model_with_attn(nn.Module):
|
| 168 |
def __init__(self):
|
| 169 |
super().__init__()
|
| 170 |
self.encoder = Encoder()
|
| 171 |
+
self.decoder = Decoder_with_attn()
|
| 172 |
|
| 173 |
def forward(self, src, tgt, tf=0.5):
|
| 174 |
enc_out, h = self.encoder(src)
|
|
|
|
| 184 |
|
| 185 |
|
| 186 |
# ------------- load trained model -------------
|
| 187 |
+
if not os.path.exists(MODEL_FILE_WITH_ATTN):
|
| 188 |
raise FileNotFoundError(
|
| 189 |
+
f"{MODEL_FILE_WITH_ATTN} not found in Space. Upload your ubuntu_chatbot_best.pt checkpoint."
|
| 190 |
)
|
| 191 |
|
| 192 |
+
model_with_attn = Model_with_attn().to(DEVICE)
|
| 193 |
+
ckpt = torch.load(MODEL_FILE_WITH_ATTN, map_location="cpu")
|
| 194 |
+
model_with_attn.load_state_dict(ckpt["model"])
|
| 195 |
+
model_with_attn.eval()
|
| 196 |
|
| 197 |
print("Model and vocab loaded. Chatbot ready to serve ")
|
| 198 |
|
| 199 |
|
| 200 |
# ------------- beam search (beam_generate_v2 from notebook) -------------
|
| 201 |
+
def beam_generate_v2(model,src_tensor, beam=5, max_len=50, alpha=0.7):
|
| 202 |
"""
|
| 203 |
src_tensor: [1, T] LongTensor with <SOS> ... <EOS>
|
| 204 |
alpha: length penalty factor
|
|
|
|
| 274 |
|
| 275 |
|
| 276 |
# ------------- wrapper to go from user text → reply -------------
|
| 277 |
+
def generate_reply_attn(user_text: str) -> str:
|
| 278 |
# replicate notebook logic: reverse the input sentence
|
| 279 |
user_text_rev = reverse(user_text)
|
| 280 |
tokens = tokenize(user_text_rev)
|
| 281 |
ids = [SOS_IDX] + [vocab.word2idx.get(w, UNK_IDX) for w in tokens] + [EOS_IDX]
|
| 282 |
src = torch.tensor([ids], dtype=torch.long, device=DEVICE)
|
| 283 |
+
reply = beam_generate_v2(model_with_attn,src, beam=5, max_len=50)
|
| 284 |
if not reply.strip():
|
| 285 |
return "I'm a chatbot trained on Ubuntu Linux support conversations, so I may not understand this question."
|
| 286 |
return reply
|
| 287 |
|
| 288 |
+
def generate_reply_no_attn(user_text: str) -> str:
|
| 289 |
"""
|
| 290 |
Inference using the ATTENTION model.
|
| 291 |
Replace body with your encoder/decoder calls (beam or greedy).
|