saadkhi committed on
Commit
107fcf0
·
verified ·
1 Parent(s): 0d17181

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -21
app.py CHANGED
@@ -5,43 +5,38 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
5
 
6
  MODEL_ID = "saadkhi/SQL_Chat_finetuned_model"
7
 
8
- app = FastAPI(title="SQL Chatbot API")
9
 
10
- # Load model once (on startup)
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
12
  model = AutoModelForCausalLM.from_pretrained(
13
  MODEL_ID,
14
- torch_dtype=torch.float16,
15
- device_map="auto"
 
16
  )
17
 
 
 
 
18
  class QueryRequest(BaseModel):
19
  prompt: str
20
  max_new_tokens: int = 256
21
 
22
- class QueryResponse(BaseModel):
23
- response: str
24
 
25
-
26
- @app.post("/generate", response_model=QueryResponse)
27
- def generate_answer(request: QueryRequest):
28
- inputs = tokenizer(
29
- request.prompt,
30
- return_tensors="pt"
31
- ).to(model.device)
32
 
33
  with torch.no_grad():
34
- output_ids = model.generate(
35
  **inputs,
36
- max_new_tokens=request.max_new_tokens,
37
  do_sample=True,
38
  temperature=0.7,
39
  top_p=0.9
40
  )
41
 
42
- output_text = tokenizer.decode(
43
- output_ids[0],
44
- skip_special_tokens=True
45
- )
46
-
47
- return {"response": output_text}
 
# Hugging Face Hub identifier of the fine-tuned SQL chat model.
MODEL_ID = "saadkhi/SQL_Chat_finetuned_model"

app = FastAPI()

# ---- LOAD ONCE ONLY ----
# Tokenizer and weights are loaded at import time so every request
# reuses the same in-memory model instead of reloading per call.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # NOTE(review): `dtype` is the newer kwarg name; older transformers
    # releases only accept `torch_dtype` — confirm the pinned version.
    dtype=torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True,
)

# Inference only: disable dropout / training-mode behavior.
model.eval()
  class QueryRequest(BaseModel):
24
  prompt: str
25
  max_new_tokens: int = 256
@app.post("/generate")
def generate(req: QueryRequest):
    """Generate a completion for ``req.prompt``.

    Runs sampled decoding (temperature 0.7, top-p 0.9) for up to
    ``req.max_new_tokens`` new tokens and returns JSON of the form
    ``{"response": <generated text>}``.

    Fix: for decoder-only models, ``model.generate`` returns the prompt
    token ids followed by the completion, so decoding ``outputs[0]``
    whole would echo the user's prompt back inside the response.  We
    slice off the prompt length before decoding.
    """
    inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():  # inference: no autograd graph needed
        outputs = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # Keep only the newly generated tokens (everything after the prompt).
    prompt_len = inputs["input_ids"].shape[1]
    text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return {"response": text}