# "Spaces: Sleeping" — Hugging Face Space status text captured by the page
# scrape; not part of the program.
| import textwrap | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import spacy | |
| import random | |
| import pandas as pd | |
| from torch.utils.data import Dataset, DataLoader | |
| from torch.nn.utils.rnn import pad_sequence | |
| from sklearn.model_selection import train_test_split | |
| from flask import Flask ,request, jsonify,send_file,after_this_request | |
| from collections import Counter | |
| from flask_cors import CORS | |
| import requests | |
| import uuid | |
| import os | |
| import time | |
# ---- Runtime / model hyperparameters ----
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # prefer GPU when present
MAX_LEN = 350            # maximum greedy-decoding steps per translation
BATCH_SIZE = 8           # DataLoader batch size (training only)
EMB_SIZE = 128           # transformer d_model / embedding width
NHEAD = 8                # attention heads (must divide EMB_SIZE)
FFN_HID_DIM = 256        # feed-forward hidden size inside each layer
NUM_ENCODER_LAYERS = 4
NUM_DECODER_LAYERS = 4
NUM_EPOCHS = 18          # used only by the commented-out training loop below
MIN_FREQ = 2             # tokens rarer than this are dropped from the vocab
PORT = 7680              # NOTE(review): defined but no app.run() appears in this chunk — confirm usage
# ==== Tokenizers ====
# Shared spaCy English pipeline; only its rule-based tokenizer is used below.
spacy_eng = spacy.load("en_core_web_sm")
def tokenize_en(text):
    """Return the lower-cased token strings spaCy produces for *text*."""
    doc = spacy_eng.tokenizer(text)
    return [token.text.lower() for token in doc]
def tokenize_te(text):
    """Whitespace-trim *text* and split it on single spaces.

    Splitting on ' ' (not .split()) is deliberate: consecutive spaces yield
    empty-string tokens, and the saved checkpoint's vocabulary was built with
    exactly this behavior.
    """
    trimmed = text.strip()
    return trimmed.split(" ")
# ==== Vocab Builder ====
def build_vocab(sentences, tokenizer, min_freq):
    """Build a token -> id mapping, reserving ids 0-3 for special symbols.

    Tokens occurring fewer than ``min_freq`` times are excluded (they fall
    back to '<unk>' at lookup time). Remaining tokens get consecutive ids in
    first-seen corpus order.
    """
    frequencies = Counter()
    for sentence in sentences:
        frequencies.update(tokenizer(sentence))
    vocabulary = {'<pad>': 0, '<sos>': 1, '<eos>': 2, '<unk>': 3}
    kept = [token for token, count in frequencies.items() if count >= min_freq]
    for next_id, token in enumerate(kept, start=len(vocabulary)):
        vocabulary[token] = next_id
    return vocabulary
# ==== Dataset ====
class TranslationDataset(Dataset):
    """Lazily yields (English ids, Telugu ids) tensor pairs from a DataFrame.

    The DataFrame must carry 'response' (English) and 'translated_response'
    (Telugu) columns; tokens missing from a vocab map to '<unk>'.
    """

    def __init__(self, df, en_vocab, te_vocab):
        self.data = df
        self.en_vocab = en_vocab
        self.te_vocab = te_vocab

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        source_tokens = ['<sos>'] + tokenize_en(row['response']) + ['<eos>']
        target_tokens = ['<sos>'] + tokenize_te(row['translated_response']) + ['<eos>']
        unk_src = self.en_vocab['<unk>']
        unk_tgt = self.te_vocab['<unk>']
        source_ids = [self.en_vocab.get(token, unk_src) for token in source_tokens]
        target_ids = [self.te_vocab.get(token, unk_tgt) for token in target_tokens]
        return torch.tensor(source_ids), torch.tensor(target_ids)
# ==== Collate Function ====
def collate_fn(batch):
    """Pad a list of (src, tgt) tensor pairs into two batch-first tensors.

    NOTE(review): reads the module-level ``en_vocab`` / ``te_vocab`` for the
    pad ids, so it must be defined after those globals exist at call time.
    """
    sources, targets = [], []
    for src, tgt in batch:
        sources.append(src)
        targets.append(tgt)
    padded_src = pad_sequence(sources, batch_first=True, padding_value=en_vocab['<pad>'])
    padded_tgt = pad_sequence(targets, batch_first=True, padding_value=te_vocab['<pad>'])
    return padded_src, padded_tgt
# ==== Transformer Model ====
class Seq2SeqTransformer(nn.Module):
    """Token-level encoder-decoder Transformer over two vocabularies.

    Attribute names (transformer, src_tok_emb, tgt_tok_emb, fc_out, dropout)
    must stay as-is: the saved checkpoint's state_dict keys depend on them.
    """

    def __init__(self, num_encoder_layers, num_decoder_layers,
                 emb_size, src_vocab_size, tgt_vocab_size,
                 nhead, dim_feedforward=512, dropout=0.1):
        super().__init__()
        self.transformer = nn.Transformer(
            d_model=emb_size,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
        )
        self.src_tok_emb = nn.Embedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = nn.Embedding(tgt_vocab_size, emb_size)
        self.fc_out = nn.Linear(emb_size, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src, tgt):
        """Map batch-first (B, S) id tensors to (B, T, tgt_vocab) logits."""
        # NOTE(review): applying a causal (subsequent) mask to the ENCODER is
        # unusual — normally only the decoder is masked — but the checkpoint
        # was trained this way, so it must be kept for the weights to match.
        encoder_mask = self.transformer.generate_square_subsequent_mask(src.size(1)).to(DEVICE)
        decoder_mask = self.transformer.generate_square_subsequent_mask(tgt.size(1)).to(DEVICE)
        src_emb = self.dropout(self.src_tok_emb(src))
        tgt_emb = self.dropout(self.tgt_tok_emb(tgt))
        # nn.Transformer expects sequence-first tensors; transpose(0, 1) on a
        # 3-D tensor is equivalent to the original permute(1, 0, 2).
        hidden = self.transformer(src_emb.transpose(0, 1), tgt_emb.transpose(0, 1),
                                  src_mask=encoder_mask, tgt_mask=decoder_mask)
        return self.fc_out(hidden.transpose(0, 1))
def translate(model, sentence, en_vocab, te_vocab, te_inv_vocab, max_len=MAX_LEN):
    """Greedily decode *sentence* (English) into a Telugu string.

    Runs the model autoregressively, appending the argmax token each step
    until '<eos>' is produced or ``max_len`` steps elapse. Unknown source
    tokens map to '<unk>'. Returns the decoded tokens joined with spaces,
    without the '<sos>'/'<eos>' markers.
    """
    model.eval()
    tokens = ['<sos>'] + tokenize_en(sentence) + ['<eos>']
    src_ids = torch.tensor([[en_vocab.get(t, en_vocab['<unk>']) for t in tokens]]).to(DEVICE)
    tgt_ids = torch.tensor([[te_vocab['<sos>']]]).to(DEVICE)
    # no_grad: pure inference — without it, autograd tracks every decoding
    # step and memory grows with the output length.
    with torch.no_grad():
        for _ in range(max_len):
            out = model(src_ids, tgt_ids)
            next_token = out.argmax(-1)[:, -1].item()
            tgt_ids = torch.cat([tgt_ids, torch.tensor([[next_token]]).to(DEVICE)], dim=1)
            if next_token == te_vocab['<eos>']:
                break
    translated = [te_inv_vocab[idx.item()] for idx in tgt_ids[0][1:]]
    # Guard the empty case (max_len=0) before peeking at the last token.
    if translated and translated[-1] == '<eos>':
        translated = translated[:-1]
    return ' '.join(translated)
# ==== Load Data ====
# Parallel corpus; columns are 'response' (English) and 'translated_response'
# (Telugu). Read at import time — the CSV must be in the working directory.
df_telugu = pd.read_csv("merged_translated_responses.csv")
# Clean NaN or non-string entries
df_telugu = df_telugu.dropna(subset=['response', 'translated_response'])
# Ensure all entries are strings
df_telugu['response'] = df_telugu['response'].astype(str)
df_telugu['translated_response'] = df_telugu['translated_response'].astype(str)
# Build vocabularies. Token ids follow corpus order, so the checkpoint loaded
# below presumably requires this exact CSV — the embedding sizes must match.
en_vocab = build_vocab(df_telugu['response'], tokenize_en, MIN_FREQ)
te_vocab = build_vocab(df_telugu['translated_response'], tokenize_te, MIN_FREQ)
te_inv_vocab = {idx: tok for tok, idx in te_vocab.items()}  # id -> token, for decoding
# Prepare Dataset & DataLoader (only exercised if the training loop below is re-enabled)
dataset = TranslationDataset(df_telugu, en_vocab, te_vocab)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
# Initialize Model (this instance is untrained; inference uses model_telugu below)
model = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                           len(en_vocab), len(te_vocab), NHEAD, FFN_HID_DIM).to(DEVICE)
pad_idx = te_vocab['<pad>']  # always 0 by construction in build_vocab
criterion_telugu = nn.CrossEntropyLoss(ignore_index=pad_idx)  # skip padding positions in the loss
optimizer_telugu = optim.Adam(model.parameters(), lr=0.0005)
# ==== Training ====
# Disabled: weights are loaded from disk below. NOTE(review): `train` is not
# defined anywhere in this file — re-enabling this loop as-is would NameError.
# for epoch in range(NUM_EPOCHS):
#     loss = train(model, dataloader, optimizer, criterion)
#     print(f"Epoch {epoch+1}, Loss: {loss:.4f}")
# ==== Try Translation ====
# Fresh model instance with the same architecture/vocab sizes as training.
model_telugu = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE, len(en_vocab), len(te_vocab), NHEAD, FFN_HID_DIM).to(DEVICE)
# Load saved weights; map_location remaps GPU-saved tensors onto the CPU.
model_telugu.load_state_dict(torch.load("english_telugu_transformer.pth", map_location=torch.device('cpu')))
model_telugu.eval()  # disable dropout for inference
app = Flask(__name__)
CORS(app)  # allow cross-origin requests from the frontend
# NOTE(review): no @app.route decorator was attached, so this view was never
# registered with Flask — most likely lost when the file was exported.
# Restoring the conventional root route; confirm the intended path.
@app.route("/")
def home():
    """Liveness check: respond with a static greeting payload."""
    return jsonify({"message": "hellooooooooo"})
# NOTE(review): no @app.route decorator was attached here either — the
# endpoint was unreachable. Restoring a POST route; confirm the intended path.
@app.route("/translate", methods=["POST"])
def translate_text():
    """Translate posted English text to Telugu.

    Expects JSON ``{"text": "..."}``. Returns ``{"english", "telugu",
    "time"}`` where ``time`` is the translation wall-clock seconds.
    Responds 400 when the body is missing, not JSON, or has empty text.
    """
    # silent=True: a non-JSON body yields None instead of Flask's HTML 400
    # page; the original `request.get_json()` made `data.get` raise
    # AttributeError (a 500) on such requests.
    data = request.get_json(silent=True) or {}
    text = data.get("text", "")
    if not text:
        return jsonify({"error": "Text cannot be empty"}), 400
    # First generate English response (currently just echoes the input)
    english_response = text
    start = time.time()
    # Then translate to Telugu
    telugu_response = translate(model_telugu, english_response, en_vocab, te_vocab, te_inv_vocab)
    end = time.time()
    return jsonify({
        "english": english_response,
        "telugu": telugu_response,
        "time": end - start,
    })