# pico_translator / app.py — Hugging Face Space by pnugues
# (Hub commit: "load from hub", 7b6adeb, verified)
import torch
import torch.nn as nn
import math
import json
import gradio as gr
from huggingface_hub import hf_hub_download
# Hugging Face Hub location of the trained model and its vocabulary.
REPO_ID = "pnugues/pico_translator"
MODEL_FILE = "pico_model.pth"  # PyTorch state_dict checkpoint
VOCAB_FILE = "pico_translator.vocab"  # JSON mapping: token -> integer index
# Download both files (hf_hub_download caches locally and returns the path).
model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE)
vocab_path = hf_hub_download(repo_id=REPO_ID, filename=VOCAB_FILE)
# Load the vocabulary: a JSON object mapping token -> integer index.
# encoding="utf-8" is explicit because the vocabulary of a French model
# contains accented characters; relying on the platform default encoding
# (e.g. cp1252 on Windows) would corrupt or reject them.
with open(vocab_path, 'r', encoding='utf-8') as f:
    token2idx = json.load(f)
# Reverse mapping for decoding model output back to tokens.
idx2token = {v: k for k, v in token2idx.items()}

# Hyperparameters — must match the values used when the checkpoint was trained.
max_len = 100                # maximum source length, in tokens
VOCAB_SIZE = len(token2idx)
D_MODEL = 512                # model / embedding width
NHEAD = 8                    # attention heads
DIM_FF = 512                 # feed-forward layer width
BATCH_SIZE = 32              # unused at inference time; kept for reference
NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
MAX_LEN = max_len + 2        # room for the <s> and </s> markers
class Embedding(nn.Module):
    """Token embedding plus fixed sinusoidal positional encodings.

    The learned token embedding is scaled by sqrt(d_model), the frozen
    positional table is added, and dropout is applied to the sum.
    """

    def __init__(self, vocab_size, d_model, dropout=0.1, max_len=max_len):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        # Index 0 is the padding token; its embedding stays at zero.
        self.input_embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        # Positional encodings are precomputed once and never trained.
        table = self.pos_encoding(max_len, d_model)
        self.pos_embedding = nn.Embedding.from_pretrained(table, freeze=True)
        self.dropout = nn.Dropout(dropout)

    def forward(self, X):
        # Positions 0..seq_len-1; the lookup broadcasts over the batch dim.
        positions = torch.arange(X.size(-1), device=X.device)
        scaled = self.input_embedding(X) * math.sqrt(self.d_model)
        return self.dropout(scaled + self.pos_embedding(positions))

    def pos_encoding(self, max_len, d_model):
        """Build the (max_len, d_model) sinusoidal table of Vaswani et al.

        Even columns hold sin(pos / 10000^(2i/d_model)), odd columns the
        matching cos. Assumes d_model is even — TODO confirm for new configs.
        """
        positions = torch.arange(max_len).unsqueeze(1)
        rates = torch.pow(10000.0, torch.arange(0, d_model, 2) / d_model)
        angles = positions / rates
        table = torch.zeros((max_len, d_model))
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles)
        return table
class Translator(nn.Module):
    """Encoder-decoder Transformer with tied input/output embeddings."""

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6,
                 dim_feedforward=2048, dropout=0.1, vocab_size=30000, max_len=128):
        super().__init__()
        self.embedding = Embedding(vocab_size, d_model, max_len=max_len)
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers,
                                          num_decoder_layers, dim_feedforward,
                                          dropout, batch_first=True)
        self.fc = nn.Linear(d_model, vocab_size)
        # Weight tying: the output projection reuses the token-embedding matrix.
        self.fc.weight = self.embedding.input_embedding.weight

    def forward(self, src, tgt, src_padding, tgt_padding):
        """Teacher-forced pass; returns (batch, tgt_len, vocab_size) logits."""
        encoder_in = self.embedding(src)
        decoder_in = self.embedding(tgt)
        # Causal mask: target position i may not attend to positions > i.
        causal = nn.Transformer.generate_square_subsequent_mask(
            tgt.size(1), device=src.device)
        decoded = self.transformer(encoder_in, decoder_in, tgt_mask=causal,
                                   src_key_padding_mask=src_padding,
                                   memory_key_padding_mask=src_padding,
                                   tgt_key_padding_mask=tgt_padding)
        return self.fc(decoded)
# Model initialisation and checkpoint loading
DEVICE = torch.device("cpu")  # Force CPU on Spaces (no GPU available)
model = Translator(d_model=D_MODEL, nhead=NHEAD, num_encoder_layers=NUM_ENCODER_LAYERS,
                   num_decoder_layers=NUM_DECODER_LAYERS, dim_feedforward=DIM_FF,
                   vocab_size=VOCAB_SIZE, max_len=MAX_LEN).to(DEVICE)
# NOTE(review): torch.load unpickles the checkpoint, which can execute
# arbitrary code; the file comes from the Hub, so only trusted repos should
# be loaded (consider weights_only=True on recent PyTorch versions).
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
model.eval()  # inference mode: disables dropout
# Helper functions
def seqs2tensors(seqs, token2idx):
    """Convert token sequences to LongTensors, wrapped in <s>/</s> markers.

    Tokens absent from the vocabulary map to index 1 (<unk>).
    """
    result = []
    for sequence in seqs:
        wrapped = ['<s>'] + list(sequence) + ['</s>']
        indices = [token2idx.get(token, 1) for token in wrapped]
        result.append(torch.LongTensor(indices))
    return result
def tensors2seqs(tensors, idx2token):
    """Map index tensors back to token lists; unknown indices become <unk>."""
    return [[idx2token.get(idx.item(), '<unk>') for idx in tensor]
            for tensor in tensors]
def greedy_decode(model, src_seq, max_len):
    """Greedily decode one (1-D) source index sequence.

    Encodes the source once, then generates target tokens left to right,
    always taking the argmax of the last decoder position. Stops at </s>
    or after min(max_len, MAX_LEN) tokens. Returns the generated indices
    without the leading <s> (a produced </s> is kept).
    """
    # Encode the source once; reuse the memory for every decode step.
    memory = model.transformer.encoder(model.embedding(src_seq))
    generated = torch.LongTensor([token2idx['<s>']]).to(DEVICE)
    # Never decode past the positional-encoding table.
    steps = min(max_len, MAX_LEN) - 1
    for _ in range(steps):
        tgt_embs = model.embedding(generated)
        causal = nn.Transformer.generate_square_subsequent_mask(
            tgt_embs.size(dim=0), device=DEVICE)
        decoded = model.transformer.decoder(tgt_embs, memory, tgt_mask=causal)
        # Most-likely next token from the last decoder position.
        next_token = torch.argmax(model.fc(decoded[-1]))
        generated = torch.cat((generated, next_token.unsqueeze(0)), dim=0)
        if next_token.item() == token2idx['</s>']:
            break
    # Strip the initial <s>.
    return generated[1:]
def translate(src_sentence):
    """Translate one French sentence via character-level greedy decoding.

    Gradio callback: must never raise, so any failure is returned as an
    "Error: ..." string instead.
    """
    try:
        source = seqs2tensors([src_sentence.strip()], token2idx)[0].to(DEVICE)
        # Allow the output to run a little longer than the input.
        budget = source.size(0) + 20
        output_ids = greedy_decode(model, source, max_len=budget)
        tokens = tensors2seqs([output_ids], idx2token)[0]
        # Drop the end-of-sequence marker when decoding produced one.
        if tokens[-1] == '</s>':
            tokens = tokens[:-1]
        return ''.join(tokens)
    except Exception as e:
        return f"Error: {str(e)}"
# Gradio interface: two-column layout with a single button wired to translate().
# Widget creation order inside the context managers defines the page layout.
with gr.Blocks(title="Pico Translator") as demo:
    gr.Markdown("# Pico Translator")
    with gr.Row():
        # Left column: French source text.
        with gr.Column():
            src_sentence = gr.Textbox(label="Source text in French", placeholder="Write your text...")
        # Right column: English translation output.
        with gr.Column():
            tgt_sentence = gr.Textbox(label="English translation", placeholder="Translation will show here...")
    btn = gr.Button("Translate!")
    # On click: run translate(src_sentence) and show the result.
    btn.click(fn=translate, inputs=[src_sentence], outputs=[tgt_sentence])
# Start the app (on Spaces this serves the demo).
demo.launch()