Kush26 committed on
Commit
8515233
·
verified ·
1 Parent(s): 2ec63b7

Update app/main.py

Browse files
Files changed (1) hide show
  1. app/main.py +102 -101
app/main.py CHANGED
@@ -1,102 +1,103 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- import torch
4
- import torch.nn.functional as F
5
- from tokenizers import Tokenizer
6
- from huggingface_hub import hf_hub_download
7
-
8
- from .model_def import BuildTransformer
9
-
10
-
11
- app = FastAPI(title="Hindi-English Translator API")
12
- model = None
13
- tokenizer = None
14
- device = torch.device("cpu")
15
-
16
-
17
- class TranslationRequest(BaseModel):
18
- text: str
19
-
20
- class TranslationResponse(BaseModel):
21
- translated_text: str
22
-
23
- @app.on_event("startup")
24
- def load_assets():
25
-
26
- global model, tokenizer, device
27
-
28
- model_file = hf_hub_download(repo_id="Kush26/Transformer_Translation", filename="model.pth")
29
- tokenizer_file = hf_hub_download(repo_id="Kush26/Transformer_Translation", filename="hindi-english_bpe_tokenizer.json")
30
-
31
- tokenizer = Tokenizer.from_file(tokenizer_file)
32
- vocab_size = tokenizer.get_vocab_size()
33
-
34
- config = {
35
- "d_model": 256,
36
- "num_layers": 6,
37
- "num_heads": 8,
38
- "d_ff": 2048,
39
- "dropout": 0.1,
40
- "max_seq_len": 512,
41
- }
42
-
43
- model = BuildTransformer(
44
- src_vocab_size=vocab_size,
45
- trg_vocab_size=vocab_size,
46
- src_seq_len=config["max_seq_len"],
47
- trg_seq_len=config["max_seq_len"],
48
- d_model=config["d_model"],
49
- N=config["num_layers"],
50
- h=config["num_heads"],
51
- dropout=config["dropout"],
52
- d_ff=config["d_ff"]
53
- ).to(device)
54
-
55
- # 5. Load the trained weights
56
- checkpoint = torch.load(model_file, map_location=device)
57
- model.load_state_dict(checkpoint['model_state_dict'])
58
- model.eval() # Set model to evaluation mode
59
-
60
- print("✅ Model and Tokenizer loaded successfully!")
61
-
62
-
63
- def greedy_decode(sentence: str, max_len=100):
64
- PAD_token = tokenizer.token_to_id('[PAD]')
65
-
66
- model.eval()
67
- src_ids = [tokenizer.token_to_id('[SOS]')] + tokenizer.encode(sentence).ids + [tokenizer.token_to_id('[EOS]')]
68
- src_tensor = torch.tensor(src_ids).unsqueeze(0).to(device)
69
- src_mask = (src_tensor != PAD_token).unsqueeze(1).unsqueeze(2)
70
-
71
- with torch.no_grad():
72
- encoder_output = model.encode(src_tensor, src_mask)
73
-
74
- tgt_tokens = [tokenizer.token_to_id('[SOS]')]
75
-
76
- for _ in range(max_len):
77
- tgt_tensor = torch.tensor(tgt_tokens).unsqueeze(0).to(device)
78
- trg_mask_padding = (tgt_tensor != PAD_token).unsqueeze(1).unsqueeze(2)
79
- subsequent_mask = torch.tril(torch.ones(1, tgt_tensor.size(1), tgt_tensor.size(1), device=device)).bool()
80
- trg_mask = trg_mask_padding & subsequent_mask
81
-
82
- with torch.no_grad():
83
- decoder_output = model.decode(encoder_output, src_mask, tgt_tensor, trg_mask)
84
- logits = model.project(decoder_output)
85
-
86
- pred_token = logits.argmax(dim=-1)[0, -1].item()
87
- tgt_tokens.append(pred_token)
88
-
89
- if pred_token == tokenizer.token_to_id('[EOS]'):
90
- break
91
-
92
- return tokenizer.decode(tgt_tokens, skip_special_tokens=True)
93
-
94
- @app.get("/")
95
- def read_root():
96
- return {"message": "Welcome to the Hindi-English Translator API"}
97
-
98
- @app.post("/translate/greedy", response_model=TranslationResponse)
99
- def translate_greedy_endpoint(request: TranslationRequest):
100
-
101
- translated_text = greedy_decode(request.text)
 
102
  return {"translated_text": translated_text}
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from tokenizers import Tokenizer
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ from .model_def import BuildTransformer
9
+
10
+
11
# FastAPI application object; the startup hook and routes in this file
# register themselves on it through its decorators.
app = FastAPI(title="Hindi-English Translator API")

# All tensors and the model are placed on the CPU.
device = torch.device("cpu")

# Filled in by load_assets() at application startup; None until then.
model = None
tokenizer = None
15
+
16
+
17
class TranslationRequest(BaseModel):
    """Request body accepted by the translation endpoint."""

    # Source sentence that will be handed to the decoder.
    text: str
19
+
20
class TranslationResponse(BaseModel):
    """Response body returned by the translation endpoint."""

    # Decoded output text produced for the request.
    translated_text: str
22
+
23
@app.on_event("startup")
def load_assets():
    """Fetch the checkpoint and tokenizer from the Hub and build the model.

    Registered as a startup hook on ``app``. On success it fills the
    module-level ``model`` and ``tokenizer`` globals; the model is moved to
    ``device`` and switched to evaluation mode.
    """
    global model, tokenizer

    hub_repo = "Kush26/Transformer_Translation"
    cache_root = "./hf_cache"

    def fetch(name):
        # Download (or reuse from the local cache directory) one repo file
        # and return its path on disk.
        return hf_hub_download(repo_id=hub_repo, filename=name, cache_dir=cache_root)

    weights_path = fetch("model.pth")
    tokenizer_path = fetch("hindi-english_bpe_tokenizer.json")

    tokenizer = Tokenizer.from_file(tokenizer_path)
    n_vocab = tokenizer.get_vocab_size()

    # Architecture hyperparameters. Presumably these must match the values
    # the checkpoint was trained with -- TODO confirm against training code.
    max_seq_len = 512
    network = BuildTransformer(
        src_vocab_size=n_vocab,  # one shared vocabulary for source and target
        trg_vocab_size=n_vocab,
        src_seq_len=max_seq_len,
        trg_seq_len=max_seq_len,
        d_model=256,
        N=6,
        h=8,
        dropout=0.1,
        d_ff=2048,
    )
    model = network.to(device)

    # Restore the trained weights.
    # NOTE(review): torch.load unpickles the downloaded file; consider
    # weights_only=True if the checkpoint holds nothing but tensors -- confirm.
    state = torch.load(weights_path, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    model.eval()  # evaluation mode for inference

    print("✅ Model and Tokenizer loaded successfully!")
62
+
63
+
64
def greedy_decode(sentence: str, max_len=100):
    """Translate *sentence* with greedy (arg-max) autoregressive decoding.

    The source is wrapped in ``[SOS]``/``[EOS]`` and encoded once. The target
    sequence starts as ``[SOS]`` and grows by one token per step, always
    taking the highest-scoring token at the last position, until ``[EOS]``
    is produced or ``max_len`` steps have run.

    Args:
        sentence: Source text to translate.
        max_len: Maximum number of decoding steps (generated tokens).

    Returns:
        The generated token ids decoded to text, special tokens skipped.

    Raises:
        RuntimeError: If the module-level model or tokenizer has not been
            loaded yet (``load_assets`` has not run).
    """
    if model is None or tokenizer is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise RuntimeError(
            "Model and tokenizer are not loaded; load_assets() must run first."
        )

    # Special-token ids do not change during a call: look each up once
    # rather than on every decoding step.
    pad_id = tokenizer.token_to_id('[PAD]')
    sos_id = tokenizer.token_to_id('[SOS]')
    eos_id = tokenizer.token_to_id('[EOS]')

    model.eval()
    src_ids = [sos_id] + tokenizer.encode(sentence).ids + [eos_id]
    src_tensor = torch.tensor(src_ids).unsqueeze(0).to(device)  # add batch dim
    src_mask = (src_tensor != pad_id).unsqueeze(1).unsqueeze(2)

    tgt_tokens = [sos_id]

    # A single no-grad scope covers the encoder pass and every decoder step.
    with torch.no_grad():
        encoder_output = model.encode(src_tensor, src_mask)

        for _ in range(max_len):
            tgt_tensor = torch.tensor(tgt_tokens).unsqueeze(0).to(device)
            tgt_len = tgt_tensor.size(1)

            # Combine the padding mask with a lower-triangular mask so each
            # position can only attend to itself and earlier positions.
            trg_mask_padding = (tgt_tensor != pad_id).unsqueeze(1).unsqueeze(2)
            subsequent_mask = torch.tril(
                torch.ones(1, tgt_len, tgt_len, device=device)
            ).bool()
            trg_mask = trg_mask_padding & subsequent_mask

            decoder_output = model.decode(encoder_output, src_mask, tgt_tensor, trg_mask)
            logits = model.project(decoder_output)

            # Best-scoring token at the last target position of the batch item.
            pred_token = logits.argmax(dim=-1)[0, -1].item()
            tgt_tokens.append(pred_token)

            if pred_token == eos_id:
                break

    return tokenizer.decode(tgt_tokens, skip_special_tokens=True)
94
+
95
@app.get("/")
def read_root():
    """Return the static welcome message served at the API root."""
    welcome = "Welcome to the Hindi-English Translator API"
    return {"message": welcome}
98
+
99
@app.post("/translate/greedy", response_model=TranslationResponse)
def translate_greedy_endpoint(request: TranslationRequest):
    """Run greedy decoding on the request text and return the result.

    The returned mapping is serialized through ``TranslationResponse``.
    """
    return {"translated_text": greedy_decode(request.text)}