File size: 2,232 Bytes
d946d7b 29703de | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 | from pydantic import BaseModel
from io import BytesIO
import requests
from model import TransformerSeq2Seq,translate
from utils import load_tokenizers_and_embeddings
import torch
# (The model class itself comes from the project's own `model` module.)
# ===== 1. Load tokenizers and embeddings once, at server start-up =====

# Provisional device choice. It is replaced a few lines below by the device
# reported by load_tokenizers_and_embeddings().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Everything returned by the loader is kept in module-level names so that it
# is shared by every request instead of being rebuilt per call.
resources = load_tokenizers_and_embeddings()
tokenizer_vi, embedding_matrix_vi = (
    resources["tokenizer_vi"],
    resources["embedding_vi"],
)
tokenizer_en, embedding_matrix_en = (
    resources["tokenizer_en"],
    resources["embedding_en"],
)
device = resources["device"]
print("✅ Tokenizers & embeddings loaded!")

# The English embedding may be a raw tensor or an nn.Embedding-like module;
# read the embedding width from whichever form was returned.
embed_dim = (
    embedding_matrix_en.size(1)
    if isinstance(embedding_matrix_en, torch.Tensor)
    else embedding_matrix_en.embedding_dim
)

# Maximum sequence length handed to the model and to translate().
max_len = 128
# NOTE(review): batch_size is not referenced in the visible part of this file.
batch_size = 32
# Instantiate the model architecture from its hyper-parameters.
_model_config = {
    "embed_dim": embed_dim,
    # Alternative noted by the original author: len(tokenizer_vi).
    "vocab_size": tokenizer_vi.vocab_size,
    # Target-side (Vietnamese) embedding that was already loaded above.
    "embedding_decoder": embedding_matrix_vi,
    "num_heads": 4,
    "num_layers": 2,
    "dim_feedforward": 256,
    "dropout": 0.1,
    "freeze_decoder_emb": True,
    "max_len": max_len,
}
model = TransformerSeq2Seq(**_model_config)
MODEL_URL = "https://huggingface.co/nemabruh404/Machine_Translation/resolve/main/model_state_dict.pt"

# Fetch the checkpoint from the Hub into memory.
# - timeout: without it a stalled connection would block server start-up
#   indefinitely (requests has no default timeout).
# - raise_for_status(): surfaces an HTTP error (404, 5xx, ...) as a clear
#   requests.HTTPError instead of handing an error page to torch.load().
_response = requests.get(MODEL_URL, timeout=60)
_response.raise_for_status()
checkpoint_bytes = BytesIO(_response.content)

# SECURITY NOTE(review): torch.load() unpickles the downloaded payload, so the
# file at MODEL_URL must be trusted. Consider passing weights_only=True if the
# installed PyTorch version supports it and the checkpoint loads that way.
checkpoint = torch.load(checkpoint_bytes, map_location=device)

# Restore the trained weights, move the model to the target device and switch
# it to evaluation mode.
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()
print("✅ Model loaded from Hugging Face Hub")
print("Model loaded")
def hf_inference_fn(inputs: str):
    """Translate one English sentence to Vietnamese with the loaded model.

    Args:
        inputs: The source (English) sentence to translate.

    Returns:
        The value returned by ``model.translate`` for this sentence.
    """
    # Import under a private alias on purpose: further down in this module the
    # FastAPI route handler is also named ``translate`` and rebinds the
    # module-level name. A bare ``translate(...)`` call here would therefore
    # resolve to the HTTP handler at request time and fail with a TypeError
    # instead of running the model's translate function.
    from model import translate as _translate_sentence

    return _translate_sentence(
        model=model,
        src_sentence=inputs,
        tokenizer_src=tokenizer_en,  # English -> input side
        tokenizer_tgt=tokenizer_vi,  # Vietnamese -> output side
        embedding_src=embedding_matrix_en,
        device=device,
        max_len=max_len,
    )
from fastapi import FastAPI
from pydantic import BaseModel
# FastAPI application instance; the route decorator below registers its
# handler on this object.
app = FastAPI()
# Request body schema for POST /translate.
class TranslateRequest(BaseModel):
    # Text to translate; the route handler passes it to hf_inference_fn as-is.
    text: str
# POST /translate: run the request text through hf_inference_fn and wrap the
# result in a JSON object under the "translation" key.
# NOTE: this handler rebinds the module-level name `translate`, which was
# bound earlier to the function imported from `model`.
@app.post("/translate")
def translate(req: TranslateRequest):
    translated_text = hf_inference_fn(req.text)
    return {"translation": translated_text}
|