nemabruh404 commited on
Commit
d946d7b
·
verified ·
1 Parent(s): 0d0d69a

Rename app.py to inference.py

Browse files
Files changed (1) hide show
  1. app.py → inference.py +64 -71
app.py → inference.py RENAMED
@@ -1,71 +1,64 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from io import BytesIO
4
- import requests
5
- from model import TransformerSeq2Seq,translate
6
- from utils import load_tokenizers_and_embeddings
7
-
8
- import torch
9
-
10
- # class mô hình của bạn
11
-
12
- app = FastAPI()
13
-
14
- # ===== 1. Load model tokenizer khi khởi động server =====
15
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
-
17
- # ===== Load 1 lần khi start server =====
18
- resources = load_tokenizers_and_embeddings()
19
- tokenizer_vi = resources["tokenizer_vi"]
20
- embedding_matrix_vi = resources["embedding_vi"]
21
- tokenizer_en = resources["tokenizer_en"]
22
- embedding_matrix_en = resources["embedding_en"]
23
- device = resources["device"]
24
-
25
- print("✅ Tokenizers & embeddings loaded!")
26
- if isinstance(embedding_matrix_en, torch.Tensor):
27
- embed_dim = embedding_matrix_en.size(1)
28
- else: # nn.Embedding
29
- embed_dim = embedding_matrix_en.embedding_dim
30
- max_len = 128
31
- batch_size = 32
32
- # Load model
33
- model = TransformerSeq2Seq(
34
- embed_dim=embed_dim,
35
- vocab_size=tokenizer_vi.vocab_size, # hoặc len(tokenizer_vi)
36
- embedding_decoder=embedding_matrix_vi, # embedding target đã có sẵn
37
- num_heads=4,
38
- num_layers=2,
39
- dim_feedforward=256,
40
- dropout=0.1,
41
- freeze_decoder_emb=True,
42
- max_len=max_len
43
- )
44
- MODEL_URL = "https://huggingface.co/nemabruh404/Machine_Translation/resolve/main/model_state_dict.pt"
45
-
46
- # Fetch model từ Hub
47
- checkpoint_bytes = BytesIO(requests.get(MODEL_URL).content)
48
- checkpoint = torch.load(checkpoint_bytes, map_location=device)
49
-
50
- # Load state dict
51
- model.load_state_dict(checkpoint["model_state_dict"])
52
- model.to(device)
53
- model.eval()
54
-
55
- print("✅ Model loaded from Hugging Face Hub")
56
- print("Model loaded")
57
- class TranslationRequest(BaseModel):
58
- text: str
59
- # ===== Endpoint dịch =====
60
- @app.post("/translate")
61
- def translate_api(req: TranslationRequest):
62
- output = translate(
63
- model=model,
64
- src_sentence=req.text,
65
- tokenizer_src=tokenizer_en, # tiếng Anh -> input
66
- tokenizer_tgt=tokenizer_vi, # tiếng Việt -> output
67
- embedding_src=embedding_matrix_en,
68
- device=device,
69
- max_len=max_len
70
- )
71
- return {"input": req.text, "translation": output}
 
1
from pydantic import BaseModel
from io import BytesIO
import requests
from model import TransformerSeq2Seq, translate
from utils import load_tokenizers_and_embeddings

import torch

# ===== 1. Load model and tokenizers once, at server startup =====

# Load tokenizers/embeddings bundle; it supplies the device to use, so we do
# not pick one ourselves here (a local torch.device guess would be dead code —
# resources["device"] is authoritative).
resources = load_tokenizers_and_embeddings()
tokenizer_vi = resources["tokenizer_vi"]
embedding_matrix_vi = resources["embedding_vi"]
tokenizer_en = resources["tokenizer_en"]
embedding_matrix_en = resources["embedding_en"]
device = resources["device"]

print("✅ Tokenizers & embeddings loaded!")

# Embedding may be a raw weight tensor or an nn.Embedding module; derive the
# embedding dimension accordingly.
if isinstance(embedding_matrix_en, torch.Tensor):
    embed_dim = embedding_matrix_en.size(1)
else:  # nn.Embedding
    embed_dim = embedding_matrix_en.embedding_dim

max_len = 128
batch_size = 32

# Build the model skeleton; weights are loaded from the Hub below.
model = TransformerSeq2Seq(
    embed_dim=embed_dim,
    vocab_size=tokenizer_vi.vocab_size,  # or len(tokenizer_vi)
    embedding_decoder=embedding_matrix_vi,  # pre-trained target-side embedding
    num_heads=4,
    num_layers=2,
    dim_feedforward=256,
    dropout=0.1,
    freeze_decoder_emb=True,
    max_len=max_len
)

MODEL_URL = "https://huggingface.co/nemabruh404/Machine_Translation/resolve/main/model_state_dict.pt"

# Fetch the checkpoint from the Hub. raise_for_status() ensures an HTTP error
# (404, 5xx) fails loudly here instead of handing an error page to torch.load;
# the timeout prevents the server from hanging forever on startup.
response = requests.get(MODEL_URL, timeout=300)
response.raise_for_status()
checkpoint_bytes = BytesIO(response.content)
checkpoint = torch.load(checkpoint_bytes, map_location=device)

# Restore weights and switch to inference mode.
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval()

print("✅ Model loaded from Hugging Face Hub")
54
+
55
def hf_inference_fn(inputs: str):
    """Translate one English sentence into Vietnamese.

    Thin wrapper around ``translate`` that fills in the module-level model,
    tokenizers, embeddings, device, and max length initialized at import time.
    """
    translation = translate(
        model=model,
        src_sentence=inputs,
        max_len=max_len,
        device=device,
        embedding_src=embedding_matrix_en,
        tokenizer_src=tokenizer_en,  # English -> source side
        tokenizer_tgt=tokenizer_vi,  # Vietnamese -> target side
    )
    return translation