Rugs25 committed on
Commit
1ffc0c0
·
verified ·
1 Parent(s): c400548

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -6
app.py CHANGED
@@ -1,22 +1,29 @@
1
  from fastapi import FastAPI, Form
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
 
4
 
5
  app = FastAPI()
6
 
7
- # Load NLLB model once on startup
8
  MODEL_NAME = "facebook/nllb-200-distilled-600M"
9
  SRC_LANG = "mar_Deva" # Marathi
10
  TGT_LANG = "eng_Latn" # English
11
 
12
- print("Loading model... This may take a minute.")
13
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
14
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
15
 
16
- tokenizer.src_lang = SRC_LANG
 
 
 
 
 
 
 
17
 
18
  @app.post("/translate")
19
  async def translate(text: str = Form(...)):
 
20
  inputs = tokenizer(text, return_tensors="pt")
21
  with torch.no_grad():
22
  generated_tokens = model.generate(
@@ -29,4 +36,9 @@ async def translate(text: str = Form(...)):
29
 
30
  @app.get("/")
31
  def root():
32
- return {"status": "NLLB API is running!"}
 
 
 
 
 
 
1
  from fastapi import FastAPI, Form
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  import torch
4
+ import os
5
 
6
  app = FastAPI()
7
 
 
8
  MODEL_NAME = "facebook/nllb-200-distilled-600M"
9
  SRC_LANG = "mar_Deva" # Marathi
10
  TGT_LANG = "eng_Latn" # English
11
 
12
+ tokenizer = None
13
+ model = None
 
14
 
15
+ def load_model():
16
+ global tokenizer, model
17
+ if tokenizer is None or model is None:
18
+ print("Loading NLLB model... This may take a minute.")
19
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
20
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
21
+ tokenizer.src_lang = SRC_LANG
22
+ print("Model loaded successfully!")
23
 
24
  @app.post("/translate")
25
  async def translate(text: str = Form(...)):
26
+ load_model()
27
  inputs = tokenizer(text, return_tensors="pt")
28
  with torch.no_grad():
29
  generated_tokens = model.generate(
 
36
 
37
  @app.get("/")
38
  def root():
39
+ return {"status": "NLLB API is running (model loads on first request)!"}
40
+
41
+ if __name__ == "__main__":
42
+ import uvicorn
43
+ port = int(os.environ.get("PORT", 8000))
44
+ uvicorn.run(app, host="0.0.0.0", port=port)