# app.py
import os

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

from src.config import Config
from src.model import TranslateModel

# Configurable sources (env vars) with local-first fallback
HF_REPO_ID = os.getenv("HF_REPO_ID", "caixiaoshun/tiny-translator-zh2en")
HF_CKPT_FILE = os.getenv("CKPT_FILE", "translate-step=290000.ckpt")
HF_TOKENIZER_FILE = os.getenv("TOKENIZER_FILE", "tokenizer.json")
LOCAL_CKPT_PATH = os.getenv("LOCAL_CKPT_PATH", "checkpoints/translate-step=290000.ckpt")
LOCAL_TOKENIZER_PATH = os.getenv("LOCAL_TOKENIZER_PATH", "checkpoints/tokenizer.json")


class Inference:
    def __init__(self, config: Config, ckpt_path: str):
        self.config = config
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # Tokenizer (local-first, else hub)
        tokenizer_path = (
            LOCAL_TOKENIZER_PATH
            if os.path.exists(LOCAL_TOKENIZER_PATH)
            else hf_hub_download(repo_id=HF_REPO_ID, filename=HF_TOKENIZER_FILE)
        )
        self.tokenizer: Tokenizer = Tokenizer.from_file(tokenizer_path)
        self.id_SOS = self.tokenizer.token_to_id("[SOS]")
        self.id_EOS = self.tokenizer.token_to_id("[EOS]")
        self.id_PAD = self.tokenizer.token_to_id("[PAD]")

        # Model
        self.model: TranslateModel = TranslateModel(config)

        # Checkpoint (local-first via the ckpt_path argument, else hub)
        ckpt_resolved = (
            ckpt_path
            if os.path.exists(ckpt_path)
            else hf_hub_download(repo_id=HF_REPO_ID, filename=HF_CKPT_FILE)
        )
        ckpt = torch.load(ckpt_resolved, map_location="cpu")["state_dict"]

        # Strip the training-time key prefix ("net." module attribute plus
        # torch.compile's "_orig_mod.") so keys match the bare model
        prefix = "net._orig_mod."
        state_dict = {}
        for k, v in ckpt.items():
            new_k = k[len(prefix):] if k.startswith(prefix) else k
            state_dict[new_k] = v
        self.model.load_state_dict(state_dict, strict=True)
        self.model.to(self.device).eval()

    @torch.no_grad()
    def greedy(self, src_ids, max_len):
        src = torch.tensor(src_ids, dtype=torch.long, device=self.device).unsqueeze(0)
        tgt = torch.tensor([[self.id_SOS]], dtype=torch.long, device=self.device)
        src_pad_mask = (src != self.id_PAD) if (self.id_PAD is not None) else None
        for _ in range(1, max_len):
            logits = self.model(src, tgt, src_pad_mask=src_pad_mask)[:, -1, :]
            index = torch.argmax(logits, dim=-1)  # [1]
            tgt = torch.cat([tgt, index.unsqueeze(-1)], dim=-1)
            if self.id_EOS is not None and index.item() == self.id_EOS:
                break
        return tgt.squeeze(0).tolist()

    @torch.no_grad()
    def top_p(self, src_ids, max_len, top_p=0.9, temperature=1.0):
        src = torch.tensor(src_ids, dtype=torch.long, device=self.device).unsqueeze(0)
        tgt = torch.tensor([[self.id_SOS]], dtype=torch.long, device=self.device)
        src_pad_mask = (src != self.id_PAD) if (self.id_PAD is not None) else None
        for _ in range(1, max_len):
            logits = self.model(src, tgt, src_pad_mask=src_pad_mask)[:, -1, :]
            if temperature != 1.0:
                logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)
            # Nucleus filtering: zero out tokens past the top_p cumulative mass,
            # always keeping the single most probable token
            sorted_probs, sorted_idx = torch.sort(probs, descending=True)
            cumsum = torch.cumsum(sorted_probs, dim=-1)
            mask = cumsum > top_p
            mask[..., 0] = False
            filtered = sorted_probs.masked_fill(mask, 0.0)
            filtered = filtered / filtered.sum(dim=-1, keepdim=True)
            next_sorted = torch.multinomial(filtered, 1)  # [1, 1]
            next_id = sorted_idx.gather(-1, next_sorted)
            tgt = torch.cat([tgt, next_id], dim=-1)
            if self.id_EOS is not None and next_id.item() == self.id_EOS:
                break
        return tgt.squeeze(0).tolist()
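
    # The beam-search scorer below length-normalizes each hypothesis with the
    # GNMT-style penalty: score = log_prob / ((5 + L) / 6) ** len_penalty, so
    # long sequences are not crowded out purely by their accumulated negative
    # log probability. A worked example with made-up numbers (a sketch, not
    # output from this model): at len_penalty = 0.6, a 10-token hypothesis
    # with log_prob = -6.0 scores -6.0 / (15 / 6) ** 0.6 ≈ -3.46, while a
    # 5-token hypothesis with log_prob = -4.0 scores
    # -4.0 / (10 / 6) ** 0.6 ≈ -2.94, so the shorter one still ranks higher.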

    @torch.no_grad()
    def beam_search(self, src_ids, max_len, beam=4, len_penalty=0.6):
        src = torch.tensor(src_ids, dtype=torch.long, device=self.device).unsqueeze(0)
        src_pad_mask = (src != self.id_PAD) if (self.id_PAD is not None) else None
        beams = [(torch.tensor([[self.id_SOS]], device=self.device), 0.0)]
        for _ in range(1, max_len):
            new_beams = []
            for seq, logp in beams:
                # Finished hypotheses are carried over unchanged
                if self.id_EOS is not None and seq[0, -1].item() == self.id_EOS:
                    new_beams.append((seq, logp))
                    continue
                logits = self.model(src, seq, src_pad_mask=src_pad_mask)[:, -1, :]
                logprobs = torch.log_softmax(logits, dim=-1)
                topk_logp, topk_idx = torch.topk(logprobs, beam, dim=-1)
                for k in range(beam):
                    next_id = topk_idx[0, k].view(1, 1)
                    next_seq = torch.cat([seq, next_id], dim=-1)
                    new_beams.append((next_seq, logp + topk_logp[0, k].item()))

            def score_fn(s, lp):
                L = s.size(1)
                return lp / ((5 + L) ** len_penalty / (5 + 1) ** len_penalty)

            new_beams.sort(key=lambda x: score_fn(x[0], x[1]), reverse=True)
            beams = new_beams[:beam]
            if self.id_EOS is not None and all(
                seq[0, -1].item() == self.id_EOS for seq, _ in beams
            ):
                break
        return beams[0][0].squeeze(0).tolist()

    def postprocess(self, ids):
        if self.id_SOS is not None and ids and ids[0] == self.id_SOS:
            ids = ids[1:]
        if self.id_EOS is not None and self.id_EOS in ids:
            ids = ids[: ids.index(self.id_EOS)]
        return self.tokenizer.decode(ids).strip()

    def translate(
        self,
        text,
        method="greedy",
        max_tokens=128,
        top_p_val=0.9,
        temperature=1.0,
        beam=4,
        len_penalty=0.6,
    ):
        src_ids = self.tokenizer.encode(text).ids
        max_len = min(max_tokens, self.config.max_len)
        if method == "greedy":
            ids = self.greedy(src_ids, max_len)
        elif method == "top-p":
            ids = self.top_p(src_ids, max_len, top_p_val, temperature)
        elif method == "beam":
            ids = self.beam_search(src_ids, max_len, beam, len_penalty)
        else:
            return f"Unknown decoding method: {method}"
        return self.postprocess(ids)


# Initialize the model once at startup
config = Config()
inference = Inference(config, LOCAL_CKPT_PATH)


def translate_api(src_text, method, max_tokens, top_p, temperature, beam, len_penalty):
    # Gradio sliders may deliver floats; cast the integer-valued ones explicitly
    return inference.translate(
        src_text,
        method=method,
        max_tokens=int(max_tokens),
        top_p_val=top_p,
        temperature=temperature,
        beam=int(beam),
        len_penalty=len_penalty,
    )


demo = gr.Interface(
    fn=translate_api,
    inputs=[
        gr.Textbox(label="Source text", placeholder="Enter text to translate", lines=4),
        gr.Radio(choices=["greedy", "top-p", "beam"], value="greedy", label="Decoding method"),
        gr.Slider(8, 512, value=128, step=1, label="Max generation length"),
        gr.Slider(0.5, 1.0, value=0.9, step=0.01, label="Top-p (top-p only)"),
        gr.Slider(0.1, 2.0, value=1.0, step=0.05, label="Temperature (top-p only)"),
        gr.Slider(1, 10, value=4, step=1, label="Beam size (beam only)"),
        gr.Slider(0.0, 2.0, value=0.6, step=0.05, label="Length penalty (beam only)"),
    ],
    outputs=gr.Textbox(label="Translation", lines=6),
    title="Tiny Translator",
)

if __name__ == "__main__":
    port = int(os.getenv("PORT", "7860"))
    demo.queue().launch(server_name="0.0.0.0", server_port=port, share=False)
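
# Usage sketch (an assumption, not part of the app): with the server running
# locally on the default port, the Interface can be driven programmatically
# with the gradio_client package; the positional arguments follow the input
# order above, and "/predict" is gr.Interface's default endpoint name:
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860")
#   result = client.predict(
#       "你好，世界", "greedy", 128, 0.9, 1.0, 4, 0.6, api_name="/predict"
#   )
#   print(result)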