# NOTE: the original paste began with "Spaces: / Sleeping / Sleeping" — Hugging Face
# Spaces page status residue, not part of the program.
# -------------------Hugging Face Ubuntu Chatbot Seq2Seq Application code-------------------
import os
import re
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import gradio as gr
import nltk
from nltk.tokenize import word_tokenize
from collections import Counter
# NOTE(review): `os` and `re` appear unused in this file — confirm before removing.

# ------------- basic setup -------------
# Fetch NLTK tokenizer data quietly at startup ('punkt_tab' is needed by newer NLTK).
nltk.download(['punkt', 'punkt_tab'], quiet=True)
# All inference runs on CPU.
DEVICE = torch.device("cpu")
# Artifacts produced by the training notebook:
VOCAB_FILE = "ubuntu_vocab_only.pt"  # To get the Vocab from cache
MODEL_FILE_WITH_ATTN = "ubuntu_chatbot_with_attn.pt"  # trained model with attn
MODEL_FILE_NO_ATTN = "ubuntu_chatbot_no_attn.pt"  # trained model without attn
# ------------- tokenization + helpers -------------
def tokenize(text: str):
    """Lower-case *text* and split it into NLTK word tokens."""
    lowered = text.lower()
    return word_tokenize(lowered)
def reverse(sentence: str) -> str:
    """Reverse the word order of *sentence* – same trick used in training."""
    words = sentence.split()
    words.reverse()
    return " ".join(words)
# ------------- Vocab class (same as training) -------------
class Vocab:
    """Word <-> index mapping with the four special tokens pre-registered."""

    def __init__(self):
        specials = ['<PAD>', '<SOS>', '<EOS>', '<UNK>']
        self.word2idx = {tok: i for i, tok in enumerate(specials)}
        self.idx2word = {i: tok for i, tok in enumerate(specials)}

    def __len__(self):
        return len(self.word2idx)

    def build(self, pairs):
        """Populate the vocab from (context, reply) pairs.

        Scans at most the 19,996 most common words (20k total including the
        specials) and stops as soon as a word's frequency drops below 3.
        """
        freq = Counter()
        for context, reply in pairs:
            freq.update(tokenize(context + " " + reply))
        for word, count in freq.most_common(19996):
            if count < 3:
                break
            if word in self.word2idx:
                continue
            next_idx = len(self.word2idx)
            self.word2idx[word] = next_idx
            self.idx2word[next_idx] = word
# ------------- load vocab from cache -------------
print("Loading vocab...")
# weights_only=False is required to unpickle the custom Vocab object.
# NOTE(review): torch.load with weights_only=False executes arbitrary pickle
# code — only load checkpoints from a trusted source.
data = torch.load(VOCAB_FILE, map_location="cpu", weights_only=False)
vocab = data["vocab"]
# Cache the special-token indices used throughout encoding/decoding.
PAD_IDX = vocab.word2idx["<PAD>"]
SOS_IDX = vocab.word2idx["<SOS>"]
EOS_IDX = vocab.word2idx["<EOS>"]
UNK_IDX = vocab.word2idx["<UNK>"]
print(f"Vocab size loaded: {len(vocab)} words")
# ------------- model definitions (same as notebook) -------------
class Encoder(nn.Module):
    """2-layer bidirectional GRU encoder projecting outputs back to 512 dims."""

    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(len(vocab), 256, padding_idx=PAD_IDX)
        # bidirectional GRU, 2 layers
        self.gru = nn.GRU(
            input_size=256,
            hidden_size=512,
            num_layers=2,
            batch_first=True,
            dropout=0.3,
            bidirectional=True,
        )
        # projection from 1024 (2 * 512, both directions) back to 512
        self.fc = nn.Linear(1024, 512)
        # NOTE(review): self.norm is never used in forward(); kept so the
        # trained checkpoint's state_dict keys still match on load.
        self.norm = nn.LayerNorm(512)

    def forward(self, x):
        # x: [B, T] token ids
        embedded = self.emb(x)
        outputs, hidden = self.gru(embedded)
        outputs = self.fc(outputs)  # [B, T, 1024] -> [B, T, 512]
        # Merge the two directions of each layer by summing their final
        # hidden states: [4, B, 512] -> [2 layers, 2 dirs, B, 512] -> [2, B, 512]
        hidden = hidden.view(2, 2, hidden.size(1), -1).sum(dim=1)
        return outputs, hidden
class Decoder_with_attn(nn.Module):
    """Single-step GRU decoder with dot-product attention over encoder outputs."""

    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(len(vocab), 256, padding_idx=PAD_IDX)
        self.dropout = nn.Dropout(0.3)
        # GRU input is the concatenation [embedding | context] = 256 + 512.
        self.gru = nn.GRU(
            input_size=256 + 512,
            hidden_size=512,
            num_layers=2,
            batch_first=True,
        )
        self.attn = nn.Linear(512, 512)
        self.out = nn.Linear(512, len(vocab))
        self.norm = nn.LayerNorm(512)

    def forward(self, inp, hidden, enc_out):
        # inp: [B, 1] current token; hidden: [2, B, 512]; enc_out: [B, T, 512]
        embedded = self.dropout(self.emb(inp))              # [B, 1, 256]
        # Score every encoder position against the TOP-layer hidden state.
        energy = self.attn(enc_out)                          # [B, T, 512]
        query = hidden[-1].unsqueeze(1)                      # [B, 1, 512]
        scores = torch.bmm(query, energy.transpose(1, 2))    # [B, 1, T]
        weights = F.softmax(scores.squeeze(1), dim=-1).unsqueeze(1)
        context = torch.bmm(weights, enc_out)                # [B, 1, 512]
        gru_in = torch.cat((embedded, context), dim=-1)      # [B, 1, 768]
        gru_out, hidden = self.gru(gru_in, hidden)
        normed = self.norm(gru_out.squeeze(1))               # [B, 512] (one step)
        return self.out(normed), hidden
class Decoder_no_attn(nn.Module):
    """Single-step GRU decoder without attention (baseline model)."""

    def __init__(self):
        super().__init__()
        # Consistency fix: use PAD_IDX (== 0) like every other module instead
        # of a hard-coded 0, so the padding index follows the vocab specials.
        self.emb = nn.Embedding(len(vocab), 256, padding_idx=PAD_IDX)
        self.dropout = nn.Dropout(0.3)  # regularization on the embedding
        self.gru = nn.GRU(256, 512, num_layers=2, batch_first=True)
        self.out = nn.Linear(512, len(vocab))
        self.norm = nn.LayerNorm(512)

    def forward(self, inp, hidden):
        # inp: [B, 1] current token; hidden: [2, B, 512]
        e = self.dropout(self.emb(inp))
        out, hidden = self.gru(e, hidden)
        out = self.norm(out.squeeze(1))  # [B, 512] (assumes a single decode step)
        return self.out(out), hidden
class Model_with_attn(nn.Module):
    """Encoder + attention decoder trained with scheduled teacher forcing."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder_with_attn()

    def forward(self, src, tgt, tf=0.5):
        enc_out, hidden = self.encoder(src)
        step_in = tgt[:, 0]  # first input is the <SOS> column
        logits_per_step = []
        for t in range(1, tgt.size(1)):
            out, hidden = self.decoder(step_in.unsqueeze(1), hidden, enc_out)
            logits_per_step.append(out)
            # Teacher forcing: with prob `tf` feed the gold token, otherwise
            # feed the model's own (detached) prediction.
            if random.random() < tf:
                step_in = tgt[:, t]
            else:
                step_in = out.argmax(-1).detach()
        return torch.stack(logits_per_step, dim=1)
class Model_no_attn(nn.Module):
    """Encoder + plain GRU decoder (no attention), with teacher forcing."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder_no_attn()

    def forward(self, src, tgt, tf=0.5):
        enc_out, hidden = self.encoder(src)  # enc_out unused by this decoder
        step_in = tgt[:, 0]  # first input is the <SOS> column
        logits_per_step = []
        for t in range(1, tgt.size(1)):
            out, hidden = self.decoder(step_in.unsqueeze(1), hidden)
            logits_per_step.append(out)
            # Teacher forcing: with prob `tf` feed the gold token, otherwise
            # feed the model's own (detached) prediction.
            if random.random() < tf:
                step_in = tgt[:, t]
            else:
                step_in = out.argmax(-1).detach()
        return torch.stack(logits_per_step, dim=1)
# ------------- load trained models -------------
# Checkpoints store the state dict under the "model" key.
# Model with attention
model_with_attn = Model_with_attn().to(DEVICE)
ckpt = torch.load(MODEL_FILE_WITH_ATTN, map_location="cpu")
model_with_attn.load_state_dict(ckpt["model"])
model_with_attn.eval()  # inference mode: disables dropout

# Model without attention
model_no_attn = Model_no_attn().to(DEVICE)
ckpt = torch.load(MODEL_FILE_NO_ATTN, map_location="cpu")
model_no_attn.load_state_dict(ckpt["model"])
model_no_attn.eval()  # inference mode: disables dropout
print("Model and vocab loaded. Chatbot ready to serve ")
# ------------- beam search (beam_generate_v2 from notebook) -------------
def beam_generate_v2(model, src_tensor, beam=5, max_len=50, alpha=0.7):
    """
    Universal beam search for both attention and no-attention models.

    alpha: Length penalty factor. 0.0 = no normalization (prefer short). 1.0 = full normalization (fair to long).

    Returns the decoded reply as a space-joined string with <SOS>/<EOS> stripped.
    """
    model.eval()
    with torch.no_grad():
        enc_out, h = model.encoder(src_tensor.to(DEVICE))
        # Beam Structure: (Normalized Score, Raw Score, Hidden, Sequence)
        beams = [(0.0, 0.0, h, [SOS_IDX])]
        for _ in range(max_len):
            candidates = []
            for norm_score, raw_score, hid, seq in beams:
                # Finished hypotheses are carried forward unchanged.
                if seq[-1] == EOS_IDX:
                    candidates.append((norm_score, raw_score, hid, seq))
                    continue
                dec_in = torch.tensor([[seq[-1]]], device=DEVICE)
                # Universal decoder call: the attention decoder also takes enc_out.
                if hasattr(model.decoder, "attn"):
                    out, new_h = model.decoder(dec_in, hid, enc_out)
                else:
                    out, new_h = model.decoder(dec_in, hid)
                probs = F.log_softmax(out, dim=-1).squeeze(0)
                # --- penalise repetition ---
                # Flat -2.0 log-prob for every token already in seq
                # (this also penalises <SOS>/<EOS>).
                for prev_token in set(seq):
                    probs[prev_token] -= 2.0
                # Over-sample so trigram blocking still leaves >= beam options.
                top = probs.topk(beam + 5)
                for val, idx in zip(top.values, top.indices):
                    token = idx.item()
                    # --- N-gram blocking: never repeat an earlier trigram ---
                    if len(seq) >= 3:
                        new_trigram = tuple(seq[-2:] + [token])
                        existing_trigrams = set(tuple(seq[i:i+3]) for i in range(len(seq)-2))
                        if new_trigram in existing_trigrams:
                            continue
                    new_raw_score = raw_score + val.item()
                    new_seq = seq + [token]
                    # --- length normalization: ((5+|Y|)/6)^alpha divisor ---
                    length_penalty = ((5 + len(new_seq)) ** alpha) / (6 ** alpha)
                    new_norm_score = new_raw_score / length_penalty
                    candidates.append((new_norm_score, new_raw_score, new_h, new_seq))
            # Sort by NORMALIZED score
            # NOTE(review): if trigram blocking ever rejected every candidate,
            # `beams` would become empty and beams[0] below would raise; not
            # observed in practice thanks to the beam+5 over-sampling.
            beams = sorted(candidates, key=lambda x: x[0], reverse=True)[:beam]
            # Stop if all top beams have finished
            if all(b[3][-1] == EOS_IDX for b in beams):
                break
        # Return the best sequence, dropping the special tokens.
        best_seq = beams[0][3]
        return " ".join([vocab.idx2word.get(i, "<UNK>") for i in best_seq[1:] if i not in [SOS_IDX, EOS_IDX]])
# ------------- wrapper to go from user text → reply -------------
_FALLBACK_REPLY = "I'm a chatbot trained on Ubuntu Linux support conversations, so I may not understand this question."


def _reply_with_model(model, user_text: str) -> str:
    """Shared pipeline: reverse words -> tokenize -> ids -> beam search.

    The two public wrappers below were identical except for the model used;
    deduplicated into this single helper.
    """
    tokens = tokenize(reverse(user_text))  # input is reversed, as in training
    ids = [SOS_IDX] + [vocab.word2idx.get(w, UNK_IDX) for w in tokens] + [EOS_IDX]
    src = torch.tensor([ids], dtype=torch.long, device=DEVICE)
    reply = beam_generate_v2(model, src, beam=5, max_len=50)
    # Empty generations fall back to a canned apology.
    return reply if reply.strip() else _FALLBACK_REPLY


def generate_reply_attn(user_text: str) -> str:
    """Generate a reply with the attention model."""
    return _reply_with_model(model_with_attn, user_text)


def generate_reply_no_attn(user_text: str) -> str:
    """Generate a reply with the no-attention baseline model."""
    return _reply_with_model(model_no_attn, user_text)
# ---------- Gradio UI --------------------------------
# ---------- Predefined prompts ----------
# Quick-start questions offered in both dropdowns; selecting one fills the textbox.
PREDEFINED = [
    "How can I install my graphics card?",
    "How to update system packages?",
    "How do I check disk usage?",
    "How to install a .deb file?",
    "How do I remove a package with apt?"
]
# ---------- Reply functions for custom Chatbot UI ----------
def reply_no_attn(message, history):
    """Gradio handler (left pane): append the user message and the baseline
    model's reply to *history*.

    Returns (updated_history, "") — the empty string clears the textbox.
    Fix: blank/whitespace-only input previously appended a lone user bubble
    with no reply; it is now ignored and history is returned unchanged.
    """
    if not message or not str(message).strip():
        return history, ""
    bot_reply = generate_reply_no_attn(message)
    return history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": bot_reply},
    ], ""
def reply_attn(message, history):
    """Gradio handler (right pane): append the user message and the attention
    model's reply to *history*.

    Returns (updated_history, "") — the empty string clears the textbox.
    Fix: blank/whitespace-only input previously appended a lone user bubble
    with no reply; it is now ignored and history is returned unchanged.
    """
    if not message or not str(message).strip():
        return history, ""
    bot_reply = generate_reply_attn(message)
    return history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": bot_reply},
    ], ""
# Two side-by-side chat panes sharing the same quick-prompt list, so the
# baseline and attention models can be compared on identical inputs.
# NOTE(review): the reply handlers emit OpenAI-style message dicts; depending
# on the Gradio version, gr.Chatbot may need type="messages" — TODO confirm.
with gr.Blocks() as demo:
    gr.Markdown("## Ubuntu Chatbot Comparison — No Attention (left) vs Attention (right)")
    gr.Markdown("Use dropdown to quickly fill the chat input. ")
    with gr.Row():
        # Left column: No Attention Model
        with gr.Column(scale=1):
            gr.Markdown("### No Attention Model")
            chatbot_left = gr.Chatbot(label="No Attention Chatbot")
            with gr.Row():
                txt_left = gr.Textbox(show_label=False, placeholder="Type your message here...")
                send_left = gr.Button("Send")
            dd_left = gr.Dropdown(choices=PREDEFINED, label="Quick prompts (left)", interactive=True)

            # Copy the selected quick prompt into the left textbox.
            def set_input_left(selected):
                return selected
            dd_left.change(fn=set_input_left, inputs=dd_left, outputs=txt_left)

            # Reset chat history and textbox when the chatbot is cleared.
            def clear_left():
                return [], ""
            send_left.click(fn=reply_no_attn, inputs=[txt_left, chatbot_left], outputs=[chatbot_left, txt_left])
            chatbot_left.clear(fn=clear_left, inputs=None, outputs=[chatbot_left, txt_left])

        # Right column: With Attention Model
        with gr.Column(scale=1):
            gr.Markdown("### With Attention Model")
            chatbot_right = gr.Chatbot(label="Attention Chatbot")
            with gr.Row():
                txt_right = gr.Textbox(show_label=False, placeholder="Type your message here...")
                send_right = gr.Button("Send")
            dd_right = gr.Dropdown(choices=PREDEFINED, label="Quick prompts (right)", interactive=True)

            # Copy the selected quick prompt into the right textbox.
            def set_input_right(selected):
                return selected
            dd_right.change(fn=set_input_right, inputs=dd_right, outputs=txt_right)

            # Reset chat history and textbox when the chatbot is cleared.
            def clear_right():
                return [], ""
            send_right.click(fn=reply_attn, inputs=[txt_right, chatbot_right], outputs=[chatbot_right, txt_right])
            chatbot_right.clear(fn=clear_right, inputs=None, outputs=[chatbot_right, txt_right])

if __name__ == "__main__":
    demo.launch()