File size: 2,232 Bytes
d946d7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29703de
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from pydantic import BaseModel
from io import BytesIO
import requests
from model import TransformerSeq2Seq,translate
from utils import load_tokenizers_and_embeddings

import torch

   # class mô hình của bạn

# ===== 1. Load model và tokenizer khi khởi động server =====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ===== Load 1 lần khi start server =====
resources = load_tokenizers_and_embeddings()
tokenizer_vi = resources["tokenizer_vi"]
embedding_matrix_vi = resources["embedding_vi"]
tokenizer_en = resources["tokenizer_en"]
embedding_matrix_en = resources["embedding_en"]
device = resources["device"]

print("✅ Tokenizers & embeddings loaded!")
if isinstance(embedding_matrix_en, torch.Tensor):
    embed_dim = embedding_matrix_en.size(1)
else:  # nn.Embedding
    embed_dim = embedding_matrix_en.embedding_dim
max_len = 128
batch_size = 32
# Load model
model = TransformerSeq2Seq(
    embed_dim=embed_dim,
    vocab_size=tokenizer_vi.vocab_size,   # hoặc len(tokenizer_vi)
    embedding_decoder=embedding_matrix_vi,      # embedding target đã có sẵn
    num_heads=4,
    num_layers=2,
    dim_feedforward=256,
    dropout=0.1,
    freeze_decoder_emb=True,
    max_len=max_len
)
MODEL_URL = "https://huggingface.co/nemabruh404/Machine_Translation/resolve/main/model_state_dict.pt"

# Fetch model từ Hub
checkpoint_bytes = BytesIO(requests.get(MODEL_URL).content)
checkpoint = torch.load(checkpoint_bytes, map_location=device)

# Load state dict
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()

print("✅ Model loaded from Hugging Face Hub")
print("Model loaded")

def hf_inference_fn(inputs: str):
    """Translate an English sentence into Vietnamese.

    Thin wrapper around the project-level ``translate`` that supplies the
    module-level model, tokenizers, embeddings and device loaded at
    server startup.
    """
    translation_kwargs = {
        "model": model,
        "src_sentence": inputs,
        "tokenizer_src": tokenizer_en,   # English -> source side
        "tokenizer_tgt": tokenizer_vi,   # Vietnamese -> target side
        "embedding_src": embedding_matrix_en,
        "device": device,
        "max_len": max_len,
    }
    return translate(**translation_kwargs)
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class TranslateRequest(BaseModel):
    # Request body schema: the English text to translate.
    text: str


@app.post("/translate")
def translate_endpoint(req: TranslateRequest):
    """POST /translate — translate English text and return the Vietnamese result.

    Bug fix: the handler was originally named ``translate``, which shadowed
    the ``translate`` function imported from ``model`` at the top of the
    file. Once the module finished loading, ``hf_inference_fn``'s call to
    ``translate(model=..., src_sentence=...)`` resolved to this route
    handler instead of the model's translate function and raised a
    TypeError on every request. Renaming the handler restores the import;
    the HTTP route path is unchanged.
    """
    return {"translation": hf_inference_fn(req.text)}