Spaces:
Sleeping
Sleeping
| import os | |
| import math | |
| import pandas as pd | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import Dataset, DataLoader | |
| from collections import Counter | |
| import streamlit as st | |
# ==========================================
# FORCE CPU FOR DEPLOYMENT ON HUGGING FACE SPACES
# ==========================================
# Single module-level device handle; it is always the CPU.
device = torch.device("cpu")
| # ========================================== | |
| # 1. TOKENIZER & VOCABULARY BUILDER | |
| # ========================================== | |
class Vocabulary:
    """Bidirectional word <-> index mapping over lower-cased, whitespace-split text.

    Indices 0-3 are reserved, in this order, for the padding, start-of-sequence,
    end-of-sequence and unknown tokens. Words added later receive consecutive
    indices starting at 4.
    """

    def __init__(self, pad_token="<PAD>", sos_token="<SOS>", eos_token="<EOS>", unk_token="<UNK>"):
        """Create a vocabulary that contains only the four special tokens."""
        self.pad_token = pad_token
        self.sos_token = sos_token
        self.eos_token = eos_token
        self.unk_token = unk_token
        self.w2i = {pad_token: 0, sos_token: 1, eos_token: 2, unk_token: 3}
        self.i2w = {0: pad_token, 1: sos_token, 2: eos_token, 3: unk_token}
        self.vocab_size = 4

    @staticmethod
    def _tokenize(sentence):
        """Return the lower-cased whitespace tokens of ``sentence`` (coerced to ``str``)."""
        return str(sentence).lower().split()

    def build_vocab(self, sentences, min_freq=1):
        """Add every sufficiently frequent word found in ``sentences``.

        Args:
            sentences: Iterable of sentences; each item is coerced with ``str``.
            min_freq: Minimum number of occurrences across ``sentences`` a word
                needs in order to be added. The default of 1 keeps every word.

        Words are indexed in order of first appearance. Words that are already
        present (including the special tokens) keep their existing index.
        """
        # Counter preserves first-seen order, so indices stay deterministic.
        counts = Counter(
            word
            for sentence in sentences
            for word in self._tokenize(sentence)
        )
        for word, freq in counts.items():
            if freq < min_freq or word in self.w2i:
                continue
            self.w2i[word] = self.vocab_size
            self.i2w[self.vocab_size] = word
            self.vocab_size += 1

    def numericalize(self, sentence):
        """Convert ``sentence`` to a list of indices, mapping unseen words to the unknown token."""
        unk_index = self.w2i[self.unk_token]
        return [self.w2i.get(token, unk_index) for token in self._tokenize(sentence)]
| # ========================================== | |
| # 2. PYTORCH DATASET & COLLATOR | |
| # ========================================== | |
class TranslationDataset(Dataset):
    """Dataset of (source, target) index tensors read from a DataFrame.

    Each row supplies the source text in its ``'english'`` column and the
    target text in its ``'spanish'`` column.
    """

    def __init__(self, df, src_vocab, trg_vocab):
        """Store the DataFrame and the vocabularies used to encode each side."""
        self.df = df
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab

    def __len__(self):
        """Return the number of rows in the DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(src, trg)`` long tensors for the row at position ``idx``.

        The target is wrapped in the target vocabulary's start- and
        end-of-sequence indices; the source is left unwrapped.
        """
        row = self.df.iloc[idx]
        src_indices = self.src_vocab.numericalize(row['english'])

        trg_vocab = self.trg_vocab
        trg_indices = [
            trg_vocab.w2i[trg_vocab.sos_token],
            *trg_vocab.numericalize(row['spanish']),
            trg_vocab.w2i[trg_vocab.eos_token],
        ]

        # dtype is explicit because torch.tensor([]) would otherwise infer a
        # float tensor for an empty sentence, which cannot index an embedding.
        return (
            torch.tensor(src_indices, dtype=torch.long),
            torch.tensor(trg_indices, dtype=torch.long),
        )
def pad_collate_fn(batch):
    """Collate ``(src, trg)`` tensor pairs into two right-padded batch tensors.

    Both outputs are batch-first and padded with index 0 up to the longest
    sequence on their respective side.
    """
    sources, targets = map(list, zip(*batch))
    pad = nn.utils.rnn.pad_sequence
    return (
        pad(sources, batch_first=True, padding_value=0),
        pad(targets, batch_first=True, padding_value=0),
    )
| # ========================================== | |
| # 3. TRANSFORMER MODEL ARCHITECTURE | |
| # ========================================== | |
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a batch-first embedding tensor.

    The table is precomputed for ``max_len`` positions and stored as a
    non-trainable buffer named ``pe`` with shape ``(1, max_len, d_model)``.
    """

    def __init__(self, d_model, max_len=100):
        """Precompute the sine/cosine table.

        Args:
            d_model: Size of the embedding dimension (odd values are supported).
            max_len: Number of positions to precompute.
        """
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so
        # the cosine half only takes the first d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        The table is sliced along dimension 1, so ``x`` is expected to be
        batch-first. The slice yields at most ``max_len`` positions; longer
        inputs fail at the addition.
        """
        return x + self.pe[:, :x.size(1)]
class PyTorchTransformer(nn.Module):
    """Encoder-decoder translation model built around ``nn.Transformer``.

    Source and target token ids are embedded separately, scaled by
    ``sqrt(d_model)``, given positional encodings and passed through a
    batch-first transformer. A final linear layer projects each target
    position onto the target vocabulary. Index 0 is treated as padding on
    both sides.
    """

    def __init__(self, src_vocab_size, trg_vocab_size, d_model=128, nhead=4, num_layers=2, dim_feedforward=256, dropout=0.1):
        """Build the embeddings, positional encoder, transformer and output projection."""
        super().__init__()
        self.d_model = d_model
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.trg_embedding = nn.Embedding(trg_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        # The encoder and decoder stacks share the same depth.
        transformer_options = {
            "d_model": d_model,
            "nhead": nhead,
            "num_encoder_layers": num_layers,
            "num_decoder_layers": num_layers,
            "dim_feedforward": dim_feedforward,
            "dropout": dropout,
            "batch_first": True,
        }
        self.transformer = nn.Transformer(**transformer_options)
        self.fc_out = nn.Linear(d_model, trg_vocab_size)

    def generate_square_subsequent_mask(self, sz, device):
        """Return a ``(sz, sz)`` float mask: 0.0 on and below the diagonal, ``-inf`` above it."""
        future_positions = torch.triu(torch.ones(sz, sz, device=device), diagonal=1).bool()
        causal = torch.zeros(sz, sz, device=device)
        return causal.masked_fill(future_positions, float('-inf'))

    def forward(self, src, trg):
        """Return vocabulary logits for every target position.

        Args:
            src: Source token ids; dimension 1 is the sequence dimension.
            trg: Target token ids; dimension 1 is the sequence dimension.
        """
        scale = math.sqrt(self.d_model)

        # Positions holding index 0 are excluded from attention.
        src_is_pad = src == 0
        trg_is_pad = trg == 0
        causal_mask = self.generate_square_subsequent_mask(trg.size(1), src.device)

        encoder_input = self.pos_encoder(self.src_embedding(src) * scale)
        decoder_input = self.pos_encoder(self.trg_embedding(trg) * scale)

        decoded = self.transformer(
            encoder_input,
            decoder_input,
            tgt_mask=causal_mask,
            src_key_padding_mask=src_is_pad,
            tgt_key_padding_mask=trg_is_pad,
            memory_key_padding_mask=src_is_pad,
        )
        return self.fc_out(decoded)
| # ========================================== | |
| # 4. STREAMLIT APP LAYOUT & LOGIC | |
| # ========================================== | |
# Page chrome: browser tab title, centered layout, heading and short description.
st.set_page_config(layout="centered", page_title="Transformer English to Spanish")
st.title("๐ Seq2Seq Transformer Translator")
st.write("An English-to-Spanish translation demo using a PyTorch Transformer built from scratch.")

# The training CSV is looked up by relative path; without it the script
# reports the problem and halts before any training code runs.
csv_filename = "data.csv"
dataset_present = os.path.exists(csv_filename)
if not dataset_present:
    st.error(f"Could not find `{csv_filename}` in the repository root directory! Please upload it to your Space.")
    st.stop()
# Cache the dataset processing and model initialization so it only executes once.
# Streamlit re-executes the whole script on every widget interaction; without
# the decorator the model would be retrained on each button click.
@st.cache_resource
def initialize_and_train():
    """Build both vocabularies from the CSV, train the Transformer, and return the artifacts.

    Reads ``csv_filename``, builds the source vocabulary from its ``'english'``
    column and the target vocabulary from its ``'spanish'`` column, then trains
    a ``PyTorchTransformer`` for 20 epochs (batch size 2, Adam, learning rate
    5e-4) with cross-entropy loss that ignores padding index 0.

    Returns:
        Tuple ``(model, eng_vocab, spa_vocab)``.
    """
    df = pd.read_csv(csv_filename)

    eng_vocab = Vocabulary()
    eng_vocab.build_vocab(df['english'])
    spa_vocab = Vocabulary()
    spa_vocab.build_vocab(df['spanish'])

    dataset = TranslationDataset(df, eng_vocab, spa_vocab)
    dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=pad_collate_fn)

    model = PyTorchTransformer(
        src_vocab_size=eng_vocab.vocab_size,
        trg_vocab_size=spa_vocab.vocab_size
    ).to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = optim.Adam(model.parameters(), lr=0.0005)

    # Progress UI placeholder for compilation/training
    status_text = st.empty()
    status_text.info("๐ ๏ธ Training model on dataset pipeline, please wait...")

    model.train()
    for epoch in range(20):
        for src, trg in dataloader:
            src, trg = src.to(device), trg.to(device)
            # Teacher forcing: the decoder sees tokens [0, n-1) and is scored
            # against tokens [1, n), i.e. the same sequence shifted by one.
            trg_input = trg[:, :-1]
            trg_output = trg[:, 1:]

            optimizer.zero_grad()
            output = model(src, trg_input)
            loss = criterion(output.reshape(-1, output.shape[-1]), trg_output.reshape(-1))
            loss.backward()
            optimizer.step()

    status_text.success("โ Model training complete and cached successfully!")
    return model, eng_vocab, spa_vocab


# Load artifacts
model, eng_vocab, spa_vocab = initialize_and_train()
def translate_sentence(model, sentence, src_vocab, trg_vocab, max_len=10):
    """Greedily decode a translation of ``sentence`` with ``model``.

    Args:
        model: Callable module taking ``(src, trg)`` index tensors and
            returning logits whose last dimension is the target vocabulary.
        sentence: Source text; tokenized by ``src_vocab.numericalize``.
        src_vocab: Vocabulary used to encode the source.
        trg_vocab: Vocabulary used to seed, stop and decode the output.
        max_len: Maximum number of tokens to generate.

    Returns:
        The generated words joined by single spaces, with start- and
        end-of-sequence tokens removed. An input that yields no tokens returns
        an empty string without running the model.
    """
    tokens = src_vocab.numericalize(sentence)
    if not tokens:
        # Nothing to encode; skip the model rather than feed it a zero-length source.
        return ""

    model.eval()
    # Run on whatever device the model's weights live on.
    run_device = next(model.parameters()).device
    sos_index = trg_vocab.w2i[trg_vocab.sos_token]
    eos_index = trg_vocab.w2i[trg_vocab.eos_token]

    src_tensor = torch.tensor(tokens, dtype=torch.long, device=run_device).unsqueeze(0)
    trg_indices = [sos_index]

    with torch.no_grad():
        for _ in range(max_len):
            trg_tensor = torch.tensor(trg_indices, dtype=torch.long, device=run_device).unsqueeze(0)
            output = model(src_tensor, trg_tensor)
            # Only the prediction at the last target position extends the output.
            best_guess = output[0, -1].argmax(dim=-1).item()
            trg_indices.append(best_guess)
            if best_guess == eos_index:
                break

    boundary_indices = {sos_index, eos_index}
    return " ".join(
        trg_vocab.i2w[idx] for idx in trg_indices if idx not in boundary_indices
    )
| # ========================================== | |
| # 5. USER INTERFACE INTERACTION | |
| # ========================================== | |
st.markdown("---")
user_input = st.text_input("Enter an English sentence to translate:", value="good morning")

# Translation runs only on the rerun triggered by the button press.
translate_clicked = st.button("Translate", type="primary")
if translate_clicked:
    if user_input.strip():
        with st.spinner("Decoding..."):
            translation = translate_sentence(model, user_input, eng_vocab, spa_vocab)
        st.markdown("### ๐ฏ Result:")
        st.success(f"**Spanish Translation:** {translation}")
    else:
        # Blank or whitespace-only input is rejected before decoding.
        st.warning("Please enter a valid text segment.")