amaresh8053 committed on
Commit
1d0f992
·
1 Parent(s): 882705a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -15
app.py CHANGED
@@ -15,7 +15,7 @@ nltk.download(['punkt', 'punkt_tab'], quiet=True)
15
  DEVICE = torch.device("cpu")
16
 
17
  CACHE_FILE = "ubuntu_data_cache.pt" # To get the Vocab from cache
18
- MODEL_FILE = "ubuntu_chatbot_best.pt" # trained model
19
 
20
 
21
  # ------------- tokenization + helpers -------------
@@ -110,7 +110,7 @@ class Encoder(nn.Module):
110
  return out, h
111
 
112
 
113
- class Decoder(nn.Module):
114
  def __init__(self):
115
  super().__init__()
116
  self.emb = nn.Embedding(len(vocab), 256, padding_idx=PAD_IDX)
@@ -147,12 +147,28 @@ class Decoder(nn.Module):
147
  logits = self.out(out)
148
  return logits, hidden
149
 
150
-
151
- class Model(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  def __init__(self):
153
  super().__init__()
154
  self.encoder = Encoder()
155
- self.decoder = Decoder()
156
 
157
  def forward(self, src, tgt, tf=0.5):
158
  enc_out, h = self.encoder(src)
@@ -168,21 +184,21 @@ class Model(nn.Module):
168
 
169
 
170
  # ------------- load trained model -------------
171
- if not os.path.exists(MODEL_FILE):
172
  raise FileNotFoundError(
173
- f"{MODEL_FILE} not found in Space. Upload your ubuntu_chatbot_best.pt checkpoint."
174
  )
175
 
176
- model = Model().to(DEVICE)
177
- ckpt = torch.load(MODEL_FILE, map_location="cpu")
178
- model.load_state_dict(ckpt["model"])
179
- model.eval()
180
 
181
  print("Model and vocab loaded. Chatbot ready to serve ")
182
 
183
 
184
  # ------------- beam search (beam_generate_v2 from notebook) -------------
185
- def beam_generate_v2(src_tensor, beam=5, max_len=50, alpha=0.7):
186
  """
187
  src_tensor: [1, T] LongTensor with <SOS> ... <EOS>
188
  alpha: length penalty factor
@@ -258,18 +274,18 @@ def beam_generate_v2(src_tensor, beam=5, max_len=50, alpha=0.7):
258
 
259
 
260
  # ------------- wrapper to go from user text → reply -------------
261
- def generate_reply_no_attn(user_text: str) -> str:
262
  # replicate notebook logic: reverse the input sentence
263
  user_text_rev = reverse(user_text)
264
  tokens = tokenize(user_text_rev)
265
  ids = [SOS_IDX] + [vocab.word2idx.get(w, UNK_IDX) for w in tokens] + [EOS_IDX]
266
  src = torch.tensor([ids], dtype=torch.long, device=DEVICE)
267
- reply = beam_generate_v2(src, beam=5, max_len=50)
268
  if not reply.strip():
269
  return "I'm a chatbot trained on Ubuntu Linux support conversations, so I may not understand this question."
270
  return reply
271
 
272
- def generate_reply_attn(user_text: str) -> str:
273
  """
274
  Inference using the ATTENTION model.
275
  Replace body with your encoder/decoder calls (beam or greedy).
 
15
  DEVICE = torch.device("cpu")
16
 
17
  CACHE_FILE = "ubuntu_data_cache.pt" # To get the Vocab from cache
18
+ MODEL_FILE_WITH_ATTN = "ubuntu_chatbot_with_attn.pt" # trained model with attn
19
 
20
 
21
  # ------------- tokenization + helpers -------------
 
110
  return out, h
111
 
112
 
113
+ class Decoder_with_attn(nn.Module):
114
  def __init__(self):
115
  super().__init__()
116
  self.emb = nn.Embedding(len(vocab), 256, padding_idx=PAD_IDX)
 
147
  logits = self.out(out)
148
  return logits, hidden
149
 
150
+ class Decoder_no_attn(nn.Module):
151
+ def __init__(self):
152
+ super().__init__()
153
+ self.emb = nn.Embedding(len(vocab), 256, padding_idx=0)
154
+ self.dropout = nn.Dropout(0.3) # added dropout layer
155
+ self.gru = nn.GRU(256, 512, num_layers=2, batch_first=True)
156
+
157
+ self.out = nn.Linear(512, len(vocab))
158
+ self.norm = nn.LayerNorm(512)
159
+
160
+ def forward(self, inp, hidden):
161
+ e = self.dropout(self.emb(inp))
162
+
163
+ out, hidden = self.gru(e, hidden)
164
+ out = self.norm(out.squeeze(1))
165
+ return self.out(out), hidden
166
+
167
+ class Model_with_attn(nn.Module):
168
  def __init__(self):
169
  super().__init__()
170
  self.encoder = Encoder()
171
+ self.decoder = Decoder_with_attn()
172
 
173
  def forward(self, src, tgt, tf=0.5):
174
  enc_out, h = self.encoder(src)
 
184
 
185
 
186
# ------------- load trained model -------------
if not os.path.exists(MODEL_FILE_WITH_ATTN):
    raise FileNotFoundError(
        # The message must name the file actually required; the previous text
        # still told users to upload the obsolete ubuntu_chatbot_best.pt.
        f"{MODEL_FILE_WITH_ATTN} not found in Space. Upload your {MODEL_FILE_WITH_ATTN} checkpoint."
    )

model_with_attn = Model_with_attn().to(DEVICE)
# Checkpoint is a dict; the state dict lives under the "model" key.
ckpt = torch.load(MODEL_FILE_WITH_ATTN, map_location="cpu")
model_with_attn.load_state_dict(ckpt["model"])
model_with_attn.eval()  # inference only: disables dropout

print("Model and vocab loaded. Chatbot ready to serve ")
198
 
199
 
200
  # ------------- beam search (beam_generate_v2 from notebook) -------------
201
+ def beam_generate_v2(model,src_tensor, beam=5, max_len=50, alpha=0.7):
202
  """
203
  src_tensor: [1, T] LongTensor with <SOS> ... <EOS>
204
  alpha: length penalty factor
 
274
 
275
 
276
  # ------------- wrapper to go from user text → reply -------------
277
def generate_reply_attn(user_text: str) -> str:
    """Turn raw user text into a reply from the attention model.

    Mirrors the notebook preprocessing: the input sentence is reversed
    before tokenization, wrapped in <SOS>/<EOS>, then decoded with beam
    search. Falls back to a canned message when decoding yields nothing.
    """
    # Notebook convention: feed the source sentence reversed.
    reversed_text = reverse(user_text)
    token_ids = [SOS_IDX]
    for word in tokenize(reversed_text):
        token_ids.append(vocab.word2idx.get(word, UNK_IDX))
    token_ids.append(EOS_IDX)
    src = torch.tensor([token_ids], dtype=torch.long, device=DEVICE)
    reply = beam_generate_v2(model_with_attn, src, beam=5, max_len=50)
    if reply.strip():
        return reply
    return "I'm a chatbot trained on Ubuntu Linux support conversations, so I may not understand this question."
287
 
288
+ def generate_reply_no_attn(user_text: str) -> str:
289
  """
290
Inference using the NON-attention model.
291
  Replace body with your encoder/decoder calls (beam or greedy).