amaresh8053 committed on
Commit
ac6e07e
·
1 Parent(s): 72c41ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +161 -67
app.py CHANGED
@@ -7,25 +7,32 @@ import torch.nn.functional as F
7
  import gradio as gr
8
  import nltk
9
  from nltk.tokenize import word_tokenize
10
- import pandas as pd
11
  from collections import Counter
12
 
13
- # ----------------- basic setup -----------------
14
  nltk.download(['punkt', 'punkt_tab'], quiet=True)
 
15
  DEVICE = torch.device("cpu")
16
 
17
- CACHE_FILE = "ubuntu_data_cache.pt"
18
- MODEL_FILE = "ubuntu_chatbot_best.pt"
 
19
 
20
- # ----------------- tokenizer -----------------
21
- def tokenize(text):
22
  return word_tokenize(text.lower())
23
 
24
- # ----------------- Vocab (same as training) -----------------
 
 
 
 
 
 
25
  class Vocab:
26
  def __init__(self):
27
- self.word2idx = {'<PAD>':0, '<SOS>':1, '<EOS>':2, '<UNK>':3}
28
- self.idx2word = {0:'<PAD>', 1:'<SOS>', 2:'<EOS>', 3:'<UNK>'}
29
 
30
  def __len__(self):
31
  return len(self.word2idx)
@@ -42,12 +49,22 @@ class Vocab:
42
  self.word2idx[w] = idx
43
  self.idx2word[idx] = w
44
 
45
- # ----------------- load vocab from cache -----------------
 
46
  if not os.path.exists(CACHE_FILE):
47
- raise FileNotFoundError(f"Cache file {CACHE_FILE} not found. Please upload ubuntu_data_cache.pt")
 
 
48
 
 
49
  cache = torch.load(CACHE_FILE, map_location="cpu", weights_only=False)
50
- # your file has keys: ['data', 'vocab']
 
 
 
 
 
 
51
  vocab = cache["vocab"]
52
 
53
  # safety: rebuild idx2word if needed
@@ -59,45 +76,77 @@ SOS_IDX = vocab.word2idx["<SOS>"]
59
  EOS_IDX = vocab.word2idx["<EOS>"]
60
  UNK_IDX = vocab.word2idx["<UNK>"]
61
 
62
- # ----------------- Model definitions (same as training) -----------------
 
63
  class Encoder(nn.Module):
64
  def __init__(self):
65
  super().__init__()
66
- self.emb = nn.Embedding(len(vocab), 256, padding_idx=0)
67
- self.gru = nn.GRU(256, 512, num_layers=2, batch_first=True, dropout=0.3)
68
- self.norm = nn.LayerNorm(512)
 
 
 
 
 
 
 
 
 
 
69
 
70
  def forward(self, x):
71
- e = self.emb(x)
72
- out, h = self.gru(e)
73
- return out, self.norm(h[-1]) # enc_out, hidden
 
 
 
 
 
 
 
 
 
 
74
 
75
  class Decoder(nn.Module):
76
  def __init__(self):
77
  super().__init__()
78
- self.emb = nn.Embedding(len(vocab), 256, padding_idx=0)
79
- self.gru = nn.GRU(256 + 512, 512, batch_first=True)
80
- self.attn = nn.Linear(1024, 1) # defined but not used directly (we use dot-product attention)
 
 
 
 
 
 
 
81
  self.out = nn.Linear(512, len(vocab))
82
  self.norm = nn.LayerNorm(512)
83
 
84
  def forward(self, inp, hidden, enc_out):
85
  """
86
- inp: [B, 1]
87
- hidden: [B, 512]
88
  enc_out:[B, T, 512]
89
  """
90
- e = self.emb(inp) # [B, 1, 256]
91
 
92
- # dot-product attention
93
- attn = torch.bmm(hidden.unsqueeze(1), enc_out.transpose(1, 2)) # [B,1,T]
94
- attn = F.softmax(attn.squeeze(1), dim=-1).unsqueeze(1) # [B,1,T]
95
- ctx = torch.bmm(attn, enc_out) # [B,1,512]
 
 
 
 
 
 
 
 
96
 
97
- x = torch.cat((e, ctx), dim=-1) # [B,1,768]
98
- out, hidden = self.gru(x, hidden.unsqueeze(0)) # out:[B,1,512]
99
- out = self.norm(out.squeeze(1)) # [B,512]
100
- return self.out(out), hidden.squeeze(0) # logits:[B,vocab], hidden:[B,512]
101
 
102
  class Model(nn.Module):
103
  def __init__(self):
@@ -107,84 +156,129 @@ class Model(nn.Module):
107
 
108
  def forward(self, src, tgt, tf=0.5):
109
  enc_out, h = self.encoder(src)
110
- dec_in = tgt[:, 0]
111
  outs = []
112
  for t in range(1, tgt.size(1)):
113
- dec_in = dec_in.unsqueeze(1)
114
  out, h = self.decoder(dec_in, h, enc_out)
115
  outs.append(out)
116
  use_tf = random.random() < tf
117
  dec_in = tgt[:, t] if use_tf else out.argmax(-1).detach()
118
  return torch.stack(outs, dim=1)
119
 
120
- # ----------------- load trained weights -----------------
121
- model = Model().to(DEVICE)
122
 
 
123
  if not os.path.exists(MODEL_FILE):
124
- raise FileNotFoundError(f"Model file {MODEL_FILE} not found. Please upload ubuntu_chatbot_best.pt")
 
 
125
 
 
126
  ckpt = torch.load(MODEL_FILE, map_location="cpu")
127
  model.load_state_dict(ckpt["model"])
128
  model.eval()
129
 
130
- # ----------------- beam search generation -----------------
131
- def beam_generate(src_tensor, beam=5, max_len=50):
 
 
 
132
  """
133
- src_tensor: [1, T] LongTensor
134
- returns: decoded string
135
  """
136
  model.eval()
137
  with torch.no_grad():
138
- enc_out, h = model.encoder(src_tensor) # enc_out:[1,T,512], h:[512]
139
- beams = [(torch.tensor([[SOS_IDX]], device=DEVICE), h, 0.0, [SOS_IDX])]
 
 
140
 
141
  for _ in range(max_len):
142
  candidates = []
143
- for inp, hid, score, seq in beams:
 
 
144
  if seq[-1] == EOS_IDX:
145
- candidates.append((score, seq))
146
  continue
147
- out, new_h = model.decoder(inp, hid, enc_out)
148
- probs = F.log_softmax(out, dim=-1).squeeze(0)
149
- top = probs.topk(beam)
 
 
 
 
 
 
 
 
 
 
150
  for val, idx in zip(top.values, top.indices):
151
  token = idx.item()
152
- candidates.append((score + val.item(), seq + [token]))
153
- beams = []
154
- for score, seq in sorted(candidates, reverse=True)[:beam]:
155
- # use last token as next input
156
- beams.append((torch.tensor([[seq[-1]]], device=DEVICE), new_h, score, seq))
157
- if not beams:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  break
159
 
160
- best_seq = sorted(candidates, reverse=True)[0][1]
161
- # convert ids to words, skipping <SOS>=1 and <EOS>=2
 
 
 
 
 
 
 
162
  words = [
163
  vocab.idx2word.get(i, "<UNK>")
164
- for i in best_seq[1:] # skip first <SOS>
165
  if i not in (SOS_IDX, EOS_IDX)
166
  ]
167
  return " ".join(words)
168
 
169
- def generate_reply(user_text):
170
- tokens = tokenize(user_text)
 
 
 
 
171
  ids = [SOS_IDX] + [vocab.word2idx.get(w, UNK_IDX) for w in tokens] + [EOS_IDX]
172
  src = torch.tensor([ids], dtype=torch.long, device=DEVICE)
173
- reply = beam_generate(src, beam=5, max_len=50)
174
  if not reply.strip():
175
  return "I don't know."
176
  return reply
177
 
178
- # ----------------- Gradio chat UI -----------------
179
- def chat_fn(message, history):
 
180
  reply = generate_reply(message)
181
- history = history + [(message, reply)]
182
- return history, ""
183
 
184
  demo = gr.ChatInterface(
185
- fn=chat_fn,
186
  title="Ubuntu Chatbot (Seq2Seq + GRU + Attention)",
187
- description="Ask questions about Ubuntu or Linux system usage."
188
  )
189
 
190
  if __name__ == "__main__":
 
7
  import gradio as gr
8
  import nltk
9
  from nltk.tokenize import word_tokenize
 
10
  from collections import Counter
11
 
12
+ # ------------- basic setup -------------
13
  nltk.download(['punkt', 'punkt_tab'], quiet=True)
14
+
15
  DEVICE = torch.device("cpu")
16
 
17
+ CACHE_FILE = "ubuntu_data_cache.pt" # from your notebook
18
+ MODEL_FILE = "ubuntu_chatbot_best.pt" # trained model checkpoint
19
+
20
 
21
+ # ------------- tokenization + helpers -------------
22
+ def tokenize(text: str):
23
  return word_tokenize(text.lower())
24
 
25
+
26
+ def reverse(sentence: str) -> str:
27
+ """Reverse word order – same trick used in training."""
28
+ return " ".join(sentence.split()[::-1])
29
+
30
+
31
+ # ------------- Vocab class (must match training) -------------
32
  class Vocab:
33
  def __init__(self):
34
+ self.word2idx = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3}
35
+ self.idx2word = {0: '<PAD>', 1: '<SOS>', 2: '<EOS>', 3: '<UNK>'}
36
 
37
  def __len__(self):
38
  return len(self.word2idx)
 
49
  self.word2idx[w] = idx
50
  self.idx2word[idx] = w
51
 
52
+
53
+ # ------------- load vocab from cache -------------
54
  if not os.path.exists(CACHE_FILE):
55
+ raise FileNotFoundError(
56
+ f"{CACHE_FILE} not found in Space. Upload the same file you used locally."
57
+ )
58
 
59
+ # cache structure in your notebook: {'data': pairs, 'vocab': vocab}
60
  cache = torch.load(CACHE_FILE, map_location="cpu", weights_only=False)
61
+
62
+ if not isinstance(cache, dict) or "vocab" not in cache:
63
+ raise RuntimeError(
64
+ f"{CACHE_FILE} does not contain a 'vocab' key. "
65
+ f"Found keys: {list(cache.keys()) if isinstance(cache, dict) else type(cache)}"
66
+ )
67
+
68
  vocab = cache["vocab"]
69
 
70
  # safety: rebuild idx2word if needed
 
76
  EOS_IDX = vocab.word2idx["<EOS>"]
77
  UNK_IDX = vocab.word2idx["<UNK>"]
78
 
79
+
80
+ # ------------- model definitions (EXACTLY as in notebook) -------------
81
  class Encoder(nn.Module):
82
  def __init__(self):
83
  super().__init__()
84
+ self.emb = nn.Embedding(len(vocab), 256, padding_idx=PAD_IDX)
85
+ # bidirectional GRU, 2 layers
86
+ self.gru = nn.GRU(
87
+ input_size=256,
88
+ hidden_size=512,
89
+ num_layers=2,
90
+ batch_first=True,
91
+ dropout=0.3,
92
+ bidirectional=True,
93
+ )
94
+ # projection from 1024 (2 * 512) back to 512
95
+ self.fc = nn.Linear(1024, 512)
96
+ self.norm = nn.LayerNorm(512) # defined in notebook (even if not used there)
97
 
98
  def forward(self, x):
99
+ # x: [B, T]
100
+ e = self.emb(x) # [B, T, 256]
101
+ out, h = self.gru(e) # out:[B,T,1024], h:[4,B,512] (2 layers * 2 dirs)
102
+
103
+ # project encoder outputs back to 512
104
+ out = self.fc(out) # [B,T,512]
105
+
106
+ # combine directions in h: reshape [layers*dirs, B, H] -> [layers, dirs, B, H]
107
+ h = h.view(2, 2, h.size(1), -1) # [2,2,B,512]
108
+ h = torch.sum(h, dim=1) # sum over directions -> [2,B,512]
109
+
110
+ return out, h # enc_out:[B,T,512], h:[2,B,512]
111
+
112
 
113
  class Decoder(nn.Module):
114
  def __init__(self):
115
  super().__init__()
116
+ self.emb = nn.Embedding(len(vocab), 256, padding_idx=PAD_IDX)
117
+ self.dropout = nn.Dropout(0.3)
118
+ # GRU: input is [emb + context] = 256 + 512
119
+ self.gru = nn.GRU(
120
+ input_size=256 + 512,
121
+ hidden_size=512,
122
+ num_layers=2,
123
+ batch_first=True,
124
+ )
125
+ self.attn = nn.Linear(512, 512)
126
  self.out = nn.Linear(512, len(vocab))
127
  self.norm = nn.LayerNorm(512)
128
 
129
  def forward(self, inp, hidden, enc_out):
130
  """
131
+ inp: [B, 1] token IDs
132
+ hidden: [2, B, 512] encoder hidden (num_layers, batch, hidden)
133
  enc_out:[B, T, 512]
134
  """
135
+ e = self.dropout(self.emb(inp)) # [B,1,256]
136
 
137
+ # attention over encoder outputs
138
+ energy = self.attn(enc_out) # [B,T,512]
139
+ # use top layer hidden state for attention
140
+ attn_scores = torch.bmm(hidden[-1].unsqueeze(1), energy.transpose(1, 2)) # [B,1,T]
141
+ attn_weights = F.softmax(attn_scores.squeeze(1), dim=-1).unsqueeze(1) # [B,1,T]
142
+ ctx = torch.bmm(attn_weights, enc_out) # [B,1,512]
143
+
144
+ x = torch.cat((e, ctx), dim=-1) # [B,1,768]
145
+ out, hidden = self.gru(x, hidden) # out:[B,1,512], hidden:[2,B,512]
146
+ out = self.norm(out.squeeze(1)) # [B,512]
147
+ logits = self.out(out) # [B,vocab]
148
+ return logits, hidden
149
 
 
 
 
 
150
 
151
  class Model(nn.Module):
152
  def __init__(self):
 
156
 
157
  def forward(self, src, tgt, tf=0.5):
158
  enc_out, h = self.encoder(src)
159
+ dec_in = tgt[:, 0] # <SOS>
160
  outs = []
161
  for t in range(1, tgt.size(1)):
162
+ dec_in = dec_in.unsqueeze(1) # [B,1]
163
  out, h = self.decoder(dec_in, h, enc_out)
164
  outs.append(out)
165
  use_tf = random.random() < tf
166
  dec_in = tgt[:, t] if use_tf else out.argmax(-1).detach()
167
  return torch.stack(outs, dim=1)
168
 
 
 
169
 
170
+ # ------------- load trained model -------------
171
  if not os.path.exists(MODEL_FILE):
172
+ raise FileNotFoundError(
173
+ f"{MODEL_FILE} not found in Space. Upload your ubuntu_chatbot_best.pt checkpoint."
174
+ )
175
 
176
+ model = Model().to(DEVICE)
177
  ckpt = torch.load(MODEL_FILE, map_location="cpu")
178
  model.load_state_dict(ckpt["model"])
179
  model.eval()
180
 
181
+ print("✅ Model and vocab loaded. Chatbot ready to serve 🚀")
182
+
183
+
184
+ # ------------- beam search (beam_generate_v2 from notebook) -------------
185
+ def beam_generate_v2(src_tensor, beam=5, max_len=50, alpha=0.7):
186
  """
187
+ src_tensor: [1, T] LongTensor with <SOS> ... <EOS>
188
+ alpha: length penalty factor
189
  """
190
  model.eval()
191
  with torch.no_grad():
192
+ enc_out, h = model.encoder(src_tensor.to(DEVICE))
193
+
194
+ # Beam entry: (normalized_score, raw_score, hidden, sequence_ids)
195
+ beams = [(0.0, 0.0, h, [SOS_IDX])]
196
 
197
  for _ in range(max_len):
198
  candidates = []
199
+
200
+ for norm_score, raw_score, hid, seq in beams:
201
+ # if last token is EOS -> keep as-is
202
  if seq[-1] == EOS_IDX:
203
+ candidates.append((norm_score, raw_score, hid, seq))
204
  continue
205
+
206
+ # decoder step: input is last token
207
+ dec_in = torch.tensor([[seq[-1]]], device=DEVICE)
208
+ out, new_h = model.decoder(dec_in, hid, enc_out)
209
+ probs = F.log_softmax(out, dim=-1).squeeze(0) # [vocab]
210
+
211
+ # penalty for repetition
212
+ for prev_token in set(seq):
213
+ probs[prev_token] -= 2.0
214
+
215
+ # take more candidates than beam, then filter
216
+ top = probs.topk(beam + 5)
217
+
218
  for val, idx in zip(top.values, top.indices):
219
  token = idx.item()
220
+
221
+ # 3-gram blocking
222
+ if len(seq) >= 3:
223
+ new_trigram = tuple(seq[-2:] + [token])
224
+ existing_trigrams = set(
225
+ tuple(seq[i:i+3]) for i in range(len(seq) - 2)
226
+ )
227
+ if new_trigram in existing_trigrams:
228
+ continue
229
+
230
+ new_raw_score = raw_score + val.item()
231
+ new_seq = seq + [token]
232
+
233
+ # length normalization
234
+ length_penalty = ((5 + len(new_seq)) ** alpha) / (6 ** alpha)
235
+ new_norm_score = new_raw_score / length_penalty
236
+
237
+ candidates.append((new_norm_score, new_raw_score, new_h, new_seq))
238
+
239
+ # keep top beam by normalized score
240
+ if not candidates:
241
  break
242
 
243
+ candidates = sorted(candidates, key=lambda x: x[0], reverse=True)
244
+ beams = candidates[:beam]
245
+
246
+ # early stop if all beams ended with EOS
247
+ if all(b[3][-1] == EOS_IDX for b in beams):
248
+ break
249
+
250
+ best_seq = beams[0][3]
251
+ # convert ids to words (skip SOS/EOS)
252
  words = [
253
  vocab.idx2word.get(i, "<UNK>")
254
+ for i in best_seq[1:]
255
  if i not in (SOS_IDX, EOS_IDX)
256
  ]
257
  return " ".join(words)
258
 
259
+
260
+ # ------------- wrapper to go from user text → reply -------------
261
+ def generate_reply(user_text: str) -> str:
262
+ # replicate notebook logic: reverse the input sentence
263
+ user_text_rev = reverse(user_text)
264
+ tokens = tokenize(user_text_rev)
265
  ids = [SOS_IDX] + [vocab.word2idx.get(w, UNK_IDX) for w in tokens] + [EOS_IDX]
266
  src = torch.tensor([ids], dtype=torch.long, device=DEVICE)
267
+ reply = beam_generate_v2(src, beam=5, max_len=50)
268
  if not reply.strip():
269
  return "I don't know."
270
  return reply
271
 
272
+
273
+ # ------------- Gradio ChatInterface -------------
274
+ def respond(message, history):
275
  reply = generate_reply(message)
276
+ return reply
 
277
 
278
  demo = gr.ChatInterface(
279
+ fn=respond,
280
  title="Ubuntu Chatbot (Seq2Seq + GRU + Attention)",
281
+ description="A generative chatbot trained on Ubuntu dialogue pairs (seq2seq with attention)."
282
  )
283
 
284
  if __name__ == "__main__":