File size: 13,221 Bytes
2b150b4
58e7083
 
 
 
 
 
559c43d
58e7083
 
 
559c43d
ac6e07e
58e7083
ac6e07e
58e7083
559c43d
9391674
1d0f992
b5345bd
ac6e07e
58e7083
ac6e07e
 
58e7083
 
ac6e07e
 
 
 
 
 
12fcc4e
58e7083
 
ac6e07e
 
58e7083
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac6e07e
 
9aba057
9391674
9aba057
58e7083
 
 
 
 
 
ac6e07e
db83dcd
 
 
12fcc4e
58e7083
 
 
ac6e07e
 
 
 
 
 
 
 
 
 
 
 
5f9f2be
58e7083
 
ac6e07e
5f9f2be
 
ac6e07e
5f9f2be
 
ac6e07e
5f9f2be
 
 
ac6e07e
5f9f2be
ac6e07e
58e7083
1d0f992
58e7083
 
ac6e07e
 
 
 
 
 
 
 
 
 
58e7083
 
 
 
5f9f2be
ac6e07e
5f9f2be
ac6e07e
9391674
 
5f9f2be
ac6e07e
5f9f2be
 
 
 
ac6e07e
58e7083
1d0f992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58e7083
 
 
1d0f992
58e7083
 
 
9391674
58e7083
 
9391674
58e7083
 
 
 
 
 
b5345bd
 
 
 
 
58e7083
b5345bd
 
9391674
b5345bd
 
9391674
b5345bd
 
 
 
 
 
 
58e7083
b5345bd
1d0f992
 
 
 
58e7083
b5345bd
 
 
 
 
 
 
12fcc4e
ac6e07e
 
 
2caf562
 
58e7083
2caf562
 
58e7083
 
 
ac6e07e
 
2caf562
ac6e07e
58e7083
 
 
ac6e07e
58e7083
ac6e07e
58e7083
ac6e07e
 
2caf562
 
4bf12f4
 
 
2caf562
ac6e07e
2caf562
ac6e07e
 
 
 
58e7083
 
2caf562
ac6e07e
 
2caf562
ac6e07e
 
 
 
2caf562
ac6e07e
 
 
2caf562
 
 
ac6e07e
 
2caf562
ac6e07e
2caf562
58e7083
ac6e07e
 
1d0f992
ac6e07e
 
58e7083
 
1d0f992
58e7083
12fcc4e
58e7083
 
1d0f992
b5345bd
 
 
 
 
 
 
 
ac6e07e
fc85342
882705a
fc85342
 
40b5050
8e78cc4
40b5050
 
 
 
 
 
4f3536e
 
40b5050
4f3536e
 
 
 
 
 
fc85342
40b5050
4f3536e
40b5050
4f3536e
 
 
 
 
 
fc85342
f23e179
b16a6f6
40b5050
8e78cc4
40b5050
75a88f9
4f3536e
4b24f62
40b5050
4f3536e
fc85342
2e7132d
4f3536e
 
 
 
 
 
 
 
 
 
 
a37a484
4f3536e
 
fc85342
2e7132d
4f3536e
 
 
 
 
 
 
 
 
b16a6f6
58e7083
4f3536e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
# -------------------Hugging Face Ubuntu Chatbot Seq2Seq Application code-------------------
import os
import re
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import nltk
from nltk.tokenize import word_tokenize
from collections import Counter

# ------------- basic setup -------------
# Tokenizer data needed by nltk.word_tokenize (downloaded quietly at import time).
nltk.download(['punkt', 'punkt_tab'], quiet=True)

# Inference runs on CPU only.
DEVICE = torch.device("cpu")

# Artefacts produced by the training notebook.
VOCAB_FILE = "ubuntu_vocab_only.pt"                 # cached Vocab object
MODEL_FILE_WITH_ATTN = "ubuntu_chatbot_with_attn.pt"  # checkpoint of the attention model
MODEL_FILE_NO_ATTN = "ubuntu_chatbot_no_attn.pt"      # checkpoint of the no-attention model


# ------------- tokenization + helpers -------------
def tokenize(text: str):
    """Lower-case *text* and split it into NLTK word tokens (list of str)."""
    lowered = text.lower()
    return word_tokenize(lowered)


def reverse(sentence: str) -> str:
    """Return *sentence* with its whitespace-separated words in reverse order.

    Runs of whitespace collapse to single spaces (``str.split`` semantics).
    This is the same input-reversal trick used in training.
    """
    words = sentence.split()
    words.reverse()
    return " ".join(words)


# ------------- Vocab class (same as training) -------------
class Vocab:
    """Bidirectional word <-> index mapping with four reserved special tokens.

    Index layout: 0 ``<PAD>``, 1 ``<SOS>``, 2 ``<EOS>``, 3 ``<UNK>``; words added
    by :meth:`build` follow from index 4.

    NOTE(review): the cached vocab loaded with ``torch.load`` is presumably a
    pickled instance of this class, so keep the class name and the
    ``word2idx`` / ``idx2word`` attribute names stable.
    """

    def __init__(self):
        self.word2idx = {'<PAD>': 0, '<SOS>': 1, '<EOS>': 2, '<UNK>': 3}
        self.idx2word = {0: '<PAD>', 1: '<SOS>', 2: '<EOS>', 3: '<UNK>'}

    def __len__(self):
        return len(self.word2idx)

    def build(self, pairs, max_words=19996, min_freq=3):
        """Add the most frequent tokens found in (context, response) *pairs*.

        Args:
            pairs: iterable of ``(context, response)`` string pairs; both sides
                are tokenized together with :func:`tokenize`.
            max_words: number of most-common tokens to consider. The default
                19996 plus the 4 special tokens caps the vocab at 20,000
                entries. ``None`` considers every token.
            min_freq: tokens seen fewer than this many times are not added.
        """
        freq = Counter()
        for context, response in pairs:
            freq.update(tokenize(context + " " + response))
        for word, count in freq.most_common(max_words):
            if count < min_freq:
                # most_common() is sorted by count, so every later token is rarer.
                break
            if word not in self.word2idx:
                idx = len(self.word2idx)
                self.word2idx[word] = idx
                self.idx2word[idx] = word


# ------------- load vocab from cache -------------
print("Loading vocab...")
# weights_only=False lets torch.load unpickle arbitrary Python objects (needed
# for the cached Vocab instance). Only point VOCAB_FILE at trusted files.
data = torch.load(VOCAB_FILE, map_location="cpu", weights_only=False)
vocab = data["vocab"]

# Indices of the special tokens, looked up from the loaded vocab.
PAD_IDX, SOS_IDX, EOS_IDX, UNK_IDX = (
    vocab.word2idx[special] for special in ("<PAD>", "<SOS>", "<EOS>", "<UNK>")
)

print(f"Vocab size loaded: {len(vocab)} words")


# ------------- model definitions (same as notebook) -------------
class Encoder(nn.Module):
    """Two-layer bidirectional GRU encoder (same architecture as the notebook).

    Token ids are embedded to 256-d and run through a 2-layer bidirectional GRU
    (512 units per direction). The 1024-d outputs are projected back to 512.
    Each layer's forward and backward final hidden states are summed, giving a
    ``[2, B, 512]`` state that can seed the 2-layer unidirectional decoders.
    """

    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(len(vocab), 256, padding_idx=PAD_IDX)
        self.gru = nn.GRU(
            input_size=256,
            hidden_size=512,
            num_layers=2,
            batch_first=True,
            dropout=0.3,
            bidirectional=True,
        )
        # 2 directions * 512 -> 512, so encoder outputs match the decoder width.
        self.fc = nn.Linear(1024, 512)
        # Not used in forward(); kept so the checkpoint's state_dict keys match
        # (load_state_dict is strict by default).
        self.norm = nn.LayerNorm(512)

    def forward(self, x):
        """Encode token ids ``x`` of shape [B, T].

        Returns:
            (outputs, hidden): projected outputs ``[B, T, 512]`` and the
            direction-summed final hidden state ``[num_layers, B, 512]``.
        """
        outputs, hidden = self.gru(self.emb(x))
        projected = self.fc(outputs)

        # h_n is (num_layers * 2, B, 512); view it as (layers, directions, B, 512)
        # and add the two directions of each layer together.
        layers = self.gru.num_layers
        merged = hidden.view(layers, 2, hidden.size(1), -1).sum(dim=1)

        return projected, merged


class Decoder_with_attn(nn.Module):
    """Single-step 2-layer GRU decoder with bilinear attention over encoder outputs.

    Score for encoder position j: ``h_top . (W @ enc_out_j + b)``. The softmax
    over positions weights ``enc_out`` into a 512-d context vector. That vector
    is concatenated with the 256-d token embedding as the GRU input.
    """

    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(len(vocab), 256, padding_idx=PAD_IDX)
        self.dropout = nn.Dropout(0.3)
        # GRU input = token embedding (256) + attention context (512).
        self.gru = nn.GRU(
            input_size=256 + 512,
            hidden_size=512,
            num_layers=2,
            batch_first=True,
        )
        self.attn = nn.Linear(512, 512)
        self.out = nn.Linear(512, len(vocab))
        self.norm = nn.LayerNorm(512)

    def forward(self, inp, hidden, enc_out):
        """Decode one step.

        Args:
            inp: current token ids; callers in this file pass shape [B, 1].
            hidden: decoder hidden state [2, B, 512].
            enc_out: encoder outputs [B, T, 512].

        Returns:
            (logits [B, vocab], new hidden state).
        """
        embedded = self.dropout(self.emb(inp))
        keys = self.attn(enc_out)
        query = hidden[-1].unsqueeze(1)                 # top GRU layer as query
        scores = query.bmm(keys.transpose(1, 2))        # [B, 1, T]
        weights = F.softmax(scores.squeeze(1), dim=-1).unsqueeze(1)
        context = weights.bmm(enc_out)                  # [B, 1, 512]

        rnn_input = torch.cat((embedded, context), dim=-1)
        rnn_out, hidden = self.gru(rnn_input, hidden)
        logits = self.out(self.norm(rnn_out.squeeze(1)))
        return logits, hidden

class Decoder_no_attn(nn.Module):
    """Single-step 2-layer GRU decoder without attention.

    The encoder conditions this decoder only through the initial hidden state.
    """

    def __init__(self):
        super().__init__()
        # Use PAD_IDX, as Encoder and Decoder_with_attn do, instead of a literal 0.
        self.emb = nn.Embedding(len(vocab), 256, padding_idx=PAD_IDX)
        self.dropout = nn.Dropout(0.3)
        self.gru = nn.GRU(256, 512, num_layers=2, batch_first=True)
        self.out = nn.Linear(512, len(vocab))
        self.norm = nn.LayerNorm(512)

    def forward(self, inp, hidden):
        """Decode one step.

        Args:
            inp: current token ids; callers in this file pass shape [B, 1].
            hidden: decoder hidden state [2, B, 512].

        Returns:
            (logits [B, vocab], new hidden state).
        """
        e = self.dropout(self.emb(inp))
        out, hidden = self.gru(e, hidden)
        out = self.norm(out.squeeze(1))
        return self.out(out), hidden

class Model_with_attn(nn.Module):
    """Encoder plus attention decoder, trained with per-step teacher forcing."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder_with_attn()

    def forward(self, src, tgt, tf=0.5):
        """Decode ``tgt[:, 1:]`` step by step.

        At every step one ``random.random()`` draw decides, for the whole
        batch, whether the next input is the ground-truth token (probability
        ``tf``) or the model's own argmax prediction.

        Returns:
            Stacked logits of shape [B, tgt_len - 1, vocab].
        """
        enc_out, hidden = self.encoder(src)
        step_input = tgt[:, 0]
        step_logits = []
        for t in range(1, tgt.size(1)):
            logits, hidden = self.decoder(step_input.unsqueeze(1), hidden, enc_out)
            step_logits.append(logits)
            if random.random() < tf:
                step_input = tgt[:, t]
            else:
                step_input = logits.argmax(-1).detach()
        return torch.stack(step_logits, dim=1)

class Model_no_attn(nn.Module):
    """Encoder plus attention-free decoder, trained with per-step teacher forcing."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder_no_attn()

    def forward(self, src, tgt, tf=0.5):
        """Decode ``tgt[:, 1:]`` step by step.

        At every step one ``random.random()`` draw decides, for the whole
        batch, whether the next input is the ground-truth token (probability
        ``tf``) or the model's own argmax prediction.

        Returns:
            Stacked logits of shape [B, tgt_len - 1, vocab].
        """
        # enc_out is not consumed: this decoder only sees the encoder's hidden state.
        enc_out, hidden = self.encoder(src)
        step_input = tgt[:, 0]
        step_logits = []
        for t in range(1, tgt.size(1)):
            logits, hidden = self.decoder(step_input.unsqueeze(1), hidden)
            step_logits.append(logits)
            teacher_forced = random.random() < tf
            step_input = tgt[:, t] if teacher_forced else logits.argmax(-1).detach()
        return torch.stack(step_logits, dim=1)

# ------------- load trained models -------------

def _restore_model(model, checkpoint_path):
    """Load ``checkpoint["model"]`` from *checkpoint_path* into *model*; return it in eval mode."""
    # NOTE(review): torch.load's default for weights_only changed in newer
    # PyTorch releases; these checkpoints are assumed to hold only tensors and
    # plain containers - confirm if loading fails after an upgrade.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    model.eval()
    return model


# Model with attention
model_with_attn = _restore_model(Model_with_attn().to(DEVICE), MODEL_FILE_WITH_ATTN)

# Model without attention
model_no_attn = _restore_model(Model_no_attn().to(DEVICE), MODEL_FILE_NO_ATTN)

print("Model and vocab loaded. Chatbot ready to serve ")


# ------------- beam search (beam_generate_v2 from notebook) -------------

def _ngram_state(seq, n):
    """Return (set of n-grams in *seq*, tuple of its last n-1 tokens); empty set if disabled/too short."""
    if n <= 0 or len(seq) < n:
        return set(), ()
    prefix = tuple(seq[len(seq) - n + 1:])
    seen = {tuple(seq[i:i + n]) for i in range(len(seq) - n + 1)}
    return seen, prefix


def beam_generate_v2(model, src_tensor, beam=5, max_len=50, alpha=0.7,
                     repetition_penalty=2.0, no_repeat_ngram_size=3):
    """
    Beam-search decode a reply; works for both the attention and no-attention models.

    Args:
        model: object with ``.encoder`` and ``.decoder``. The decoder is called
            with ``enc_out`` only if it has an ``attn`` attribute. The model is
            switched to eval mode as a side effect.
        src_tensor: source token ids; the generate_reply_* helpers pass [1, T].
        beam: number of hypotheses kept after every step.
        max_len: maximum number of decoding steps.
        alpha: length-penalty exponent. 0.0 = no normalization (prefers short
            replies); 1.0 = full normalization (fair to long replies).
        repetition_penalty: subtracted from the log-probability of every
            distinct token already in a hypothesis before it is expanded.
        no_repeat_ngram_size: a hypothesis may not repeat an n-gram of this
            size (3 = trigram blocking); 0 disables the check.

    Returns:
        The highest-scoring hypothesis as a space-joined string, with
        <SOS>/<EOS> removed.
    """
    model.eval()
    with torch.no_grad():
        enc_out, h = model.encoder(src_tensor.to(DEVICE))
        uses_attention = hasattr(model.decoder, "attn")  # loop-invariant

        # Beam entry: (normalized score, raw log-prob score, hidden, token ids)
        beams = [(0.0, 0.0, h, [SOS_IDX])]

        for _ in range(max_len):
            candidates = []
            for norm_score, raw_score, hid, seq in beams:
                if seq[-1] == EOS_IDX:
                    # Finished hypotheses compete unchanged with new expansions.
                    candidates.append((norm_score, raw_score, hid, seq))
                    continue

                dec_in = torch.tensor([[seq[-1]]], device=DEVICE)
                if uses_attention:
                    out, new_h = model.decoder(dec_in, hid, enc_out)
                else:
                    out, new_h = model.decoder(dec_in, hid)
                log_probs = F.log_softmax(out, dim=-1).squeeze(0)

                # --- penalise repetition ---
                for prev_token in set(seq):
                    log_probs[prev_token] -= repetition_penalty

                # --- n-gram blocking: computed once per beam, not per candidate ---
                seen_ngrams, prefix = _ngram_state(seq, no_repeat_ngram_size)

                # A few extra candidates in case some get blocked.
                top = log_probs.topk(beam + 5)
                for val, idx in zip(top.values, top.indices):
                    token = idx.item()
                    if seen_ngrams and prefix + (token,) in seen_ngrams:
                        continue
                    new_raw_score = raw_score + val.item()
                    new_seq = seq + [token]
                    # GNMT-style length penalty: ((5 + |Y|)^alpha) / (6^alpha)
                    length_penalty = ((5 + len(new_seq)) ** alpha) / (6 ** alpha)
                    candidates.append(
                        (new_raw_score / length_penalty, new_raw_score, new_h, new_seq)
                    )

            if not candidates:
                # Every expansion was blocked. Keep the current beams instead of
                # continuing with an empty list, which would crash on beams[0].
                break
            # Sort by NORMALIZED score and keep the best `beam` hypotheses.
            beams = sorted(candidates, key=lambda x: x[0], reverse=True)[:beam]
            if all(b[3][-1] == EOS_IDX for b in beams):
                break

        best_seq = beams[0][3]
        return " ".join(
            vocab.idx2word.get(i, "<UNK>") for i in best_seq[1:] if i not in (SOS_IDX, EOS_IDX)
        )


# ------------- wrapper to go from user text → reply -------------
_FALLBACK_REPLY = "I'm a chatbot trained on Ubuntu Linux support conversations, so I may not understand this question."


def _generate_reply(model, user_text: str) -> str:
    """Encode *user_text* like the training inputs and beam-decode a reply with *model*.

    The word order is reversed before tokenizing (same trick as training).
    Unknown words map to UNK_IDX, and the ids are wrapped in <SOS> ... <EOS>.
    An empty decoded reply is replaced by a fixed fallback message.
    """
    tokens = tokenize(reverse(user_text))
    ids = [SOS_IDX] + [vocab.word2idx.get(w, UNK_IDX) for w in tokens] + [EOS_IDX]
    src = torch.tensor([ids], dtype=torch.long, device=DEVICE)
    reply = beam_generate_v2(model, src, beam=5, max_len=50)
    return reply if reply.strip() else _FALLBACK_REPLY


def generate_reply_attn(user_text: str) -> str:
    """Reply to *user_text* using the attention model."""
    return _generate_reply(model_with_attn, user_text)


def generate_reply_no_attn(user_text: str) -> str:
    """Reply to *user_text* using the no-attention model."""
    return _generate_reply(model_no_attn, user_text)


# ---------- Gradio UI --------------------------------

# ---------- Predefined prompts ----------
PREDEFINED = [
    # Sample questions offered by the quick-prompt dropdowns of both columns.
    "How can I install my graphics card?",
    "How to update system packages?",
    "How do I check disk usage?",
    "How to install a .deb file?",
    "How do I remove a package with apt?",
]

# ---------- Reply functions for custom Chatbot UI ----------
def reply_no_attn(message, history):
    if not message or not str(message).strip():
        return history + [{"role": "user", "content": message}], ""
    bot_reply = generate_reply_no_attn(message)
    history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": bot_reply}
    ]
    return history, ""

def reply_attn(message, history):
    if not message or not str(message).strip():
        return history + [{"role": "user", "content": message}], ""
    bot_reply = generate_reply_attn(message)
    history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": bot_reply}
    ]
    return history, ""

def _build_chat_column(title, chatbot_label, dropdown_label, reply_fn):
    """Build one chat column: chatbot, textbox, Send button and quick-prompt dropdown.

    Must be called inside an active ``gr.Blocks`` layout context. Both the Send
    button and pressing Enter in the textbox call *reply_fn*. Returns the
    ``(chatbot, textbox, send_button, dropdown)`` components.
    """
    with gr.Column(scale=1):
        gr.Markdown(title)
        chatbot = gr.Chatbot(label=chatbot_label)
        with gr.Row():
            textbox = gr.Textbox(show_label=False, placeholder="Type your message here...")
            send = gr.Button("Send")
        dropdown = gr.Dropdown(choices=PREDEFINED, label=dropdown_label, interactive=True)

        def set_input(selected):
            # Copy the chosen quick prompt into the textbox.
            return selected

        def clear_chat():
            # Empty history and textbox when the chatbot is cleared.
            return [], ""

        dropdown.change(fn=set_input, inputs=dropdown, outputs=textbox)
        send.click(fn=reply_fn, inputs=[textbox, chatbot], outputs=[chatbot, textbox])
        # Enter in the textbox sends the message, same as the Send button.
        textbox.submit(fn=reply_fn, inputs=[textbox, chatbot], outputs=[chatbot, textbox])
        chatbot.clear(fn=clear_chat, inputs=None, outputs=[chatbot, textbox])
    return chatbot, textbox, send, dropdown


with gr.Blocks() as demo:
    gr.Markdown("## Ubuntu Chatbot Comparison — No Attention (left) vs Attention (right)")
    gr.Markdown("Use dropdown to quickly fill the chat input. ")

    with gr.Row():
        # Left column: No Attention Model
        chatbot_left, txt_left, send_left, dd_left = _build_chat_column(
            "### No Attention Model", "No Attention Chatbot", "Quick prompts (left)", reply_no_attn
        )
        # Right column: With Attention Model
        chatbot_right, txt_right, send_right, dd_right = _build_chat_column(
            "### With Attention Model", "Attention Chatbot", "Quick prompts (right)", reply_attn
        )

if __name__ == "__main__":
    demo.launch()