hello-ram committed on
Commit
ae5b614
·
verified ·
1 Parent(s): d398613

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -19
app.py CHANGED
@@ -5,36 +5,38 @@ import torch
5
 
6
  app = FastAPI()
7
 
8
- # ---- Load your HF model repo ----
9
  MODEL_REPO = "hello-ram/mpt-model"
10
 
11
- print("Loading tokenizer...")
12
- tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
13
 
14
- print("Loading model...")
15
- model = AutoModelForCausalLM.from_pretrained(
16
- MODEL_REPO,
17
- torch_dtype=torch.float16,
18
- device_map="auto"
19
- )
20
 
21
- # ---------- ROUTES -------------
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  @app.get("/")
24
  async def root():
25
  return {
26
- "message": "🚀 FastAPI MPT Model Running on Hugging Face Spaces",
27
- "endpoints": ["/", "/status", "/generate"]
28
  }
29
 
 
30
  @app.get("/status")
31
  async def status():
32
- return {
33
- "status": "ok",
34
- "model": MODEL_REPO,
35
- "device": str(model.device),
36
- "torch_dtype": str(model.dtype)
37
- }
38
 
39
 
40
  class InputText(BaseModel):
@@ -43,8 +45,9 @@ class InputText(BaseModel):
43
 
44
  @app.post("/generate")
45
  async def generate_text(data: InputText):
46
- inputs = tokenizer(data.text, return_tensors="pt").to(model.device)
47
 
 
48
  output = model.generate(
49
  **inputs,
50
  max_new_tokens=200,
 
5
 
6
  app = FastAPI()
7
 
 
8
  MODEL_REPO = "hello-ram/mpt-model"
9
 
10
+ tokenizer = None
11
+ model = None
12
 
 
 
 
 
 
 
13
 
14
+ def load_model():
15
+ global tokenizer, model
16
+ if tokenizer is None:
17
+ print("Loading tokenizer...")
18
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO)
19
+
20
+ if model is None:
21
+ print("Loading model...")
22
+ model = AutoModelForCausalLM.from_pretrained(
23
+ MODEL_REPO,
24
+ dtype=torch.float16,
25
+ device_map="auto"
26
+ )
27
+
28
 
29
  @app.get("/")
30
  async def root():
31
  return {
32
+ "message": "🚀 FastAPI MPT Model Running",
33
+ "routes": ["/", "/status", "/generate"]
34
  }
35
 
36
+
37
  @app.get("/status")
38
  async def status():
39
+ return {"status": "ok"}
 
 
 
 
 
40
 
41
 
42
  class InputText(BaseModel):
 
45
 
46
  @app.post("/generate")
47
  async def generate_text(data: InputText):
48
+ load_model()
49
 
50
+ inputs = tokenizer(data.text, return_tensors="pt").to(model.device)
51
  output = model.generate(
52
  **inputs,
53
  max_new_tokens=200,