Mayur74 committed on
Commit
a350c84
·
verified ·
1 Parent(s): 728b6a3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# app.py
# FastAPI inference service for a fine-tuned TinyLlama causal language model.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os

# Hugging Face Hub repo id of the fine-tuned model.
MODEL_ID = "Mayur74/tinyllama-finetune-mayur"  # or local folder name if you upload merged_model/

# Computed once at import time; used to pick dtype/device placement below.
USE_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# If your model is private, the token should be set via Secrets (HF_SECRETS)
HF_TOKEN = os.environ.get("HF_TOKEN")  # set this in Space Settings -> Secrets (recommended for private models)

app = FastAPI(title="TinyLlama Inference")
15
+
16
class GenRequest(BaseModel):
    """Request body for the /generate endpoint."""

    # Text prompt the model should continue.
    prompt: str
    # Upper bound on how many new tokens model.generate() may produce.
    max_new_tokens: int = 128
    # Sampling temperature forwarded to model.generate().
    temperature: float = 0.7
20
+
21
# Load tokenizer & model once at startup
def load_model():
    """Load the tokenizer and causal LM for MODEL_ID once at startup.

    Returns:
        tuple: ``(tokenizer, model)`` with the model in eval mode.
        On CUDA the model is loaded in float16 with ``device_map="auto"``;
        on CPU it stays in float32 on the default device.
    """
    # `token=` replaces the deprecated `use_auth_token=` argument
    # (deprecated in transformers 4.x, removed in v5).
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16 if USE_DEVICE == "cuda" else torch.float32,
        device_map="auto" if USE_DEVICE == "cuda" else None,
        trust_remote_code=False,  # change if model needs trust_remote_code
        token=HF_TOKEN,
    )
    # Many causal-LM tokenizers ship without a pad token; the /generate
    # handler already pads with EOS, so mirror that here for consistency.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Inference only — disable dropout etc.
    model.eval()
    return tokenizer, model
34
+
35
# Loaded once at import time; both handles are shared by every request.
tokenizer, model = load_model()
36
+
37
+ @app.post("/generate")
38
+ async def generate(req: GenRequest):
39
+ try:
40
+ inputs = tokenizer(req.prompt, return_tensors="pt", truncation=True)
41
+ if torch.cuda.is_available():
42
+ inputs = {k: v.to("cuda") for k, v in inputs.items()}
43
+
44
+ with torch.no_grad():
45
+ out = model.generate(
46
+ **inputs,
47
+ max_new_tokens=req.max_new_tokens,
48
+ temperature=req.temperature,
49
+ do_sample=True,
50
+ top_p=0.95,
51
+ pad_token_id=tokenizer.eos_token_id,
52
+ )
53
+ text = tokenizer.decode(out[0], skip_special_tokens=True)
54
+ return {"generated_text": text}
55
+ except Exception as e:
56
+ raise HTTPException(status_code=500, detail=str(e))