| | import gradio as gr |
| | import torch |
| | import torch.nn as nn |
| | import json |
| | import math |
| | import re |
| |
|
| | |
# Load the shared token -> id vocabulary used by both translation models.
# JSON is UTF-8 by specification (RFC 8259), so decode explicitly instead of
# relying on the platform default encoding (which breaks e.g. on Windows).
with open("vocabulary.json", "r", encoding="utf-8") as f:
    vocab = json.load(f)
class Config:
    """Hyperparameters and runtime settings shared by both models.

    NOTE(review): vocab_size is hard-coded; presumably it matches
    len(vocab) from vocabulary.json — confirm against the checkpoint.
    """
    vocab_size = 12006        # size of the joint source/target vocabulary
    max_length = 100          # maximum sequence length for positional encoding
    embed_dim = 256           # model width (d_model)
    num_heads = 8             # attention heads per layer
    num_layers = 2            # encoder layers == decoder layers
    feedforward_dim = 512     # inner dimension of the position-wise FFN
    dropout = 0.1
    # Prefer GPU when available; all tensors/models are moved here.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


config = Config()
|
| | |
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to token embeddings.

    Implements the non-learned encoding from "Attention Is All You Need":
    even embedding channels receive sin, odd channels cos, with wavelengths
    forming a geometric progression from 2*pi to 10000*2*pi.
    """

    def __init__(self, embed_dim, max_len=100):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Fix: register as a buffer so the table follows the module across
        # .to(device)/.cuda() calls. persistent=False keeps it out of the
        # state_dict, so existing checkpoints (which have no "pe" key)
        # still load without strict-mode errors.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return x plus positional encodings.

        Args:
            x: (batch, seq_len, embed_dim) embedding tensor; seq_len must
               not exceed max_len.
        """
        # Buffer lives on the module's device, so the per-call
        # .to(x.device) copy of the original is no longer needed.
        return x + self.pe[:, :x.size(1)]
|
| | |
class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder Transformer over a joint token vocabulary.

    A single embedding table is shared between source and target, since
    both sides are indexed against the same vocabulary.
    """

    def __init__(self, config):
        super(Seq2SeqTransformer, self).__init__()
        # Fix: cache the model width on the instance. The original forward()
        # read the module-level `config` global, which silently breaks any
        # model constructed from a different config object.
        self.embed_dim = config.embed_dim
        self.embedding = nn.Embedding(config.vocab_size, config.embed_dim)
        self.positional_encoding = PositionalEncoding(config.embed_dim, config.max_length)
        self.transformer = nn.Transformer(
            d_model=config.embed_dim,
            nhead=config.num_heads,
            num_encoder_layers=config.num_layers,
            num_decoder_layers=config.num_layers,
            dim_feedforward=config.feedforward_dim,
            dropout=config.dropout
        )
        self.fc_out = nn.Linear(config.embed_dim, config.vocab_size)

    def forward(self, src, tgt, tgt_mask=None):
        """Compute next-token logits for each target position.

        Args:
            src: (batch, src_len) LongTensor of source token ids.
            tgt: (batch, tgt_len) LongTensor of target token ids.
            tgt_mask: optional (tgt_len, tgt_len) additive causal mask.
                Defaults to None to preserve the original behavior; pass
                nn.Transformer.generate_square_subsequent_mask(tgt_len)
                for proper autoregressive masking during training.

        Returns:
            (batch, tgt_len, vocab_size) logits.
        """
        scale = math.sqrt(self.embed_dim)
        src_emb = self.positional_encoding(self.embedding(src) * scale)
        tgt_emb = self.positional_encoding(self.embedding(tgt) * scale)
        # nn.Transformer defaults to (seq, batch, dim) layout, hence the permutes.
        out = self.transformer(
            src_emb.permute(1, 0, 2),
            tgt_emb.permute(1, 0, 2),
            tgt_mask=tgt_mask,
        )
        return self.fc_out(out.permute(1, 0, 2))
|
| | |
def load_model(path):
    """Instantiate a Seq2SeqTransformer, restore weights from *path*,
    and return it in eval mode on config.device."""
    net = Seq2SeqTransformer(config).to(config.device)
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoint files from trusted sources.
    state = torch.load(path, map_location=config.device)
    net.load_state_dict(state)
    net.eval()
    return net


# One checkpoint per translation direction.
cpp_to_pseudo_model = load_model("cpp_to_pseudo_epoch_1.pth")
pseudo_to_cpp_model = load_model("transformer_epoch_1.pth")
|
| | |
def is_valid_output(output):
    """Heuristic sanity check on a generated translation.

    Rejects text containing a run of five or more operator/punctuation
    characters, or more than two "<unk>" tokens; accepts everything else.
    """
    has_symbol_run = re.search(r"([+=*()\-]){5,}", output) is not None
    too_many_unks = output.count("<unk>") > 2
    return not (has_symbol_run or too_many_unks)
|
def is_meaningful_translation(output):
    """Return True when the text has at least three purely alphanumeric tokens."""
    alnum_tokens = sum(1 for token in output.split() if token.isalnum())
    return alnum_tokens > 2
|
def translate_with_check(model, input_tokens, vocab, device, max_length=50):
    """Translate *input_tokens* and run quality gates on the result.

    Returns the translation, or a user-facing warning string when it
    fails validity or meaningfulness checks.
    """
    candidate = translate(model, input_tokens, vocab, device, max_length)
    if not is_valid_output(candidate):
        return "⚠ Invalid translation detected. Try refining the input."
    if not is_meaningful_translation(candidate):
        return "⚠ Translation is not meaningful. Try a different input."
    return candidate
|
| | |
def translate(model, input_tokens, vocab, device, max_length=50):
    """Greedy-decode a translation of *input_tokens*.

    Args:
        model: seq2seq module called as model(src, tgt) returning
            (batch, tgt_len, vocab_size) logits.
        input_tokens: list of source tokens; out-of-vocabulary tokens map
            to the <unk> id.
        vocab: token -> id mapping; must contain <unk>, <start>, <end>.
        device: torch device to place the tensors on.
        max_length: maximum number of decoding steps.

    Returns:
        Decoded tokens joined by single spaces, without the <start> or
        <end> markers.
    """
    model.eval()
    input_ids = [vocab.get(token, vocab["<unk>"]) for token in input_tokens]
    input_tensor = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)
    end_id = vocab["<end>"]
    output_ids = [vocab["<start>"]]
    for _ in range(max_length):
        output_tensor = torch.tensor(output_ids, dtype=torch.long).unsqueeze(0).to(device)
        with torch.no_grad():
            predictions = model(input_tensor, output_tensor)
        # Greedy choice: argmax over the logits at the last target position.
        next_token_id = predictions.argmax(dim=-1)[:, -1].item()
        if next_token_id == end_id:
            # Fix: stop without appending <end>, so the marker never leaks
            # into the user-visible output (the original joined it in).
            break
        output_ids.append(next_token_id)
    id_to_token = {idx: token for token, idx in vocab.items()}
    # Skip output_ids[0], the <start> marker.
    return " ".join(id_to_token.get(idx, "<unk>") for idx in output_ids[1:])
|
| | |
# Gradio front end: one text box in, direction selector, one text box out.
with gr.Blocks() as demo:
    gr.Markdown("# C++ & Pseudocode Translator")
    with gr.Row():
        input_text = gr.Textbox(label="Enter code:")
        mode = gr.Radio(["C++ → Pseudocode", "Pseudocode → C++"], label="Translation Mode")
    output_text = gr.Textbox(label="Translated Output")
    translate_button = gr.Button("Translate")

    def translate_text(input_text, mode):
        """Gradio callback: whitespace-tokenize the input, pick the model
        for the selected direction, and return the checked translation."""
        tokens = input_text.strip().split()
        # The mode string must match the Radio choice exactly; anything else
        # falls through to pseudocode -> C++.
        model = cpp_to_pseudo_model if mode == "C++ → Pseudocode" else pseudo_to_cpp_model
        return translate_with_check(model, tokens, vocab, config.device)

    translate_button.click(translate_text, inputs=[input_text, mode], outputs=output_text)

# Start the web server (blocking call).
demo.launch()