import torch
import torch.nn as nn
import math
import json
import gradio as gr
from huggingface_hub import hf_hub_download

REPO_ID = "pnugues/pico_translator"
MODEL_FILE = "pico_model.pth"
VOCAB_FILE = "pico_translator.vocab"

model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE)
vocab_path = hf_hub_download(repo_id=REPO_ID, filename=VOCAB_FILE)
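# hf_hub_download fetches each file from the Hugging Face Hub and caches it
# locally, returning the path to the cached copy.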
# Load the vocabulary
with open(vocab_path, 'r') as f:
    token2idx = json.load(f)
idx2token = {v: k for k, v in token2idx.items()}
# Hyperparameters (must match the trained checkpoint)
max_len = 100
VOCAB_SIZE = len(token2idx)
D_MODEL = 512
NHEAD = 8
DIM_FF = 512
BATCH_SIZE = 32
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
MAX_LEN = max_len + 2  # room for the <s> and </s> markers
# Embedding class: token embeddings plus fixed sinusoidal positional encodings
class Embedding(nn.Module):
    def __init__(self, vocab_size, d_model, dropout=0.1, max_len=max_len):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        self.input_embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        pe = self.pos_encoding(max_len, d_model)
        self.pos_embedding = nn.Embedding.from_pretrained(pe, freeze=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, X):
        pos_mat = torch.arange(X.size(-1), device=X.device)
        X = self.input_embedding(X) * math.sqrt(self.d_model)
        X += self.pos_embedding(pos_mat)
        return self.dropout(X)

    def pos_encoding(self, max_len, d_model):
        # PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
        # PE(pos, 2i + 1) = cos(pos / 10000^(2i/d_model))
        dividend = torch.arange(max_len).unsqueeze(0).T
        divisor = torch.pow(10000.0, torch.arange(0, d_model, 2) / d_model)
        angles = dividend / divisor
        pe = torch.zeros((max_len, d_model))
        pe[:, 0::2] = torch.sin(angles)
        pe[:, 1::2] = torch.cos(angles)
        return pe
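# Illustrative shape check (not executed by the app): for a batch of token ids
# of shape (batch, seq_len), the module returns (batch, seq_len, d_model), e.g.
#   emb = Embedding(VOCAB_SIZE, D_MODEL)
#   emb(torch.zeros(2, 5, dtype=torch.long)).shape  # torch.Size([2, 5, 512])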
# Translator class: standard encoder-decoder Transformer
class Translator(nn.Module):
    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6,
                 dim_feedforward=2048, dropout=0.1, vocab_size=30000, max_len=128):
        super().__init__()
        self.embedding = Embedding(vocab_size, d_model, max_len=max_len)
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers,
                                          dim_feedforward, dropout, batch_first=True)
        self.fc = nn.Linear(d_model, vocab_size)
        # Weight tying: the output projection shares the input embedding matrix
        self.fc.weight = self.embedding.input_embedding.weight

    def forward(self, src, tgt, src_padding, tgt_padding):
        src_embs = self.embedding(src)
        tgt_embs = self.embedding(tgt)
        # Causal mask so each target position only attends to earlier positions
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1), device=src.device)
        x = self.transformer(src_embs, tgt_embs, tgt_mask=tgt_mask,
                             src_key_padding_mask=src_padding, memory_key_padding_mask=src_padding,
                             tgt_key_padding_mask=tgt_padding)
        return self.fc(x)
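# With batch_first=True, forward() takes LongTensors src (batch, src_len) and
# tgt (batch, tgt_len) plus boolean padding masks of matching lengths, and
# returns logits of shape (batch, tgt_len, vocab_size).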
# Model initialization and loading
DEVICE = torch.device("cpu")  # Force CPU on Spaces
model = Translator(d_model=D_MODEL, nhead=NHEAD, num_encoder_layers=NUM_ENCODER_LAYERS,
                   num_decoder_layers=NUM_DECODER_LAYERS, dim_feedforward=DIM_FF,
                   vocab_size=VOCAB_SIZE, max_len=MAX_LEN).to(DEVICE)
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
model.eval()
# Helper functions
def seqs2tensors(seqs, token2idx):
    tensors = []
    for seq in seqs:
        seq = ['<s>'] + list(seq) + ['</s>']
        tensors += [torch.LongTensor(
            [token2idx.get(x, 1) for x in seq])]  # unknown characters -> <unk> (index 1)
    return tensors

def tensors2seqs(tensors, idx2token):
    seqs = []
    for tensor in tensors:
        seqs += [[idx2token.get(x.item(), '<unk>') for x in tensor]]
    return seqs
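# Example with a hypothetical vocabulary: given
# token2idx = {'<pad>': 0, '<unk>': 1, '<s>': 2, '</s>': 3, 'a': 4},
# seqs2tensors(['a?'], token2idx) returns [tensor([2, 4, 1, 3])]; the unknown
# character '?' maps to <unk> (index 1).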
@torch.no_grad()  # inference only: no gradient tracking
def greedy_decode(model, src_seq, max_len):
    src_embs = model.embedding(src_seq)
    memory = model.transformer.encoder(src_embs)
    tgt_seq = torch.LongTensor([token2idx['<s>']]).to(DEVICE)
    tgt_embs = model.embedding(tgt_seq)
    max_len = min(max_len, MAX_LEN)
    for _ in range(max_len - 1):
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(
            tgt_embs.size(dim=0), device=DEVICE)
        tgt_output = model.transformer.decoder(tgt_embs,
                                               memory,
                                               tgt_mask=tgt_mask)
        char_prob = model.fc(tgt_output[-1])
        next_char = torch.argmax(char_prob)
        tgt_seq = torch.cat(
            (tgt_seq,
             torch.unsqueeze(next_char, dim=0)), dim=0)
        tgt_embs = model.embedding(tgt_seq)
        if next_char.item() == token2idx['</s>']:
            break
    return tgt_seq[1:]
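# Greedy decoding encodes the source once, then repeatedly runs the decoder on
# the tokens generated so far and appends the argmax token, stopping at </s>
# or after max_len steps. The leading <s> is stripped from the returned sequence.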
def translate(src_sentence):
    try:
        src = seqs2tensors([src_sentence.strip()], token2idx)[0].to(DEVICE)
        num_chars = src.size(0)
        tgt_chars = greedy_decode(model, src, max_len=num_chars + 20)
        tgt_chars = tensors2seqs([tgt_chars], idx2token)[0]
        if tgt_chars[-1] == '</s>':
            tgt_chars = tgt_chars[:-1]
        tgt_str = ''.join(tgt_chars)
        return tgt_str
    except Exception as e:
        return f"Error: {str(e)}"
# Gradio interface
with gr.Blocks(title="Pico Translator") as demo:
    gr.Markdown("# Pico Translator")
    with gr.Row():
        with gr.Column():
            src_sentence = gr.Textbox(label="Source text in French", placeholder="Write your text...")
        with gr.Column():
            tgt_sentence = gr.Textbox(label="English translation", placeholder="Translation will show here...")
    btn = gr.Button("Translate!")
    btn.click(fn=translate, inputs=[src_sentence], outputs=[tgt_sentence])

demo.launch()