saadkhi committed
Commit c6fae16 · verified · 1 Parent(s): 3940292

Update app.py

Files changed (1):
  1. app.py +31 -87
app.py CHANGED
@@ -1,103 +1,47 @@
-# app.py - Fixed for recent Gradio versions (no allow_flagging)
-
 import torch
-import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
-from peft import PeftModel
-
-# ────────────────────────────────────────────────────────────────
-# Fastest practical configuration
-# ────────────────────────────────────────────────────────────────
-
-BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
-LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"
+from fastapi import FastAPI
+from pydantic import BaseModel
+from transformers import AutoTokenizer, AutoModelForCausalLM
 
-MAX_NEW_TOKENS = 180
-TEMPERATURE = 0.0     # greedy = fastest
-DO_SAMPLE = False
+MODEL_ID = "saadkhi/SQL_Chat_finetuned_model"
 
-# ────────────────────────────────────────────────────────────────
-# 4-bit quantization (very important for speed)
-# ────────────────────────────────────────────────────────────────
+app = FastAPI(title="SQL Chatbot API")
 
-bnb_config = BitsAndBytesConfig(
-    load_in_4bit = True,
-    bnb_4bit_quant_type = "nf4",
-    bnb_4bit_use_double_quant = True,
-    bnb_4bit_compute_dtype = torch.bfloat16
-)
-
-print("Loading quantized base model...")
+# Load model once (on startup)
+tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 model = AutoModelForCausalLM.from_pretrained(
-    BASE_MODEL,
-    quantization_config = bnb_config,
-    device_map = "auto",
-    trust_remote_code = True,
-    torch_dtype = torch.bfloat16
+    MODEL_ID,
+    torch_dtype=torch.float16,
+    device_map="auto"
 )
 
-print("Loading LoRA adapters...")
-model = PeftModel.from_pretrained(model, LORA_PATH)
-
-# Merge LoRA into base model → much faster inference
-model = model.merge_and_unload()
+class QueryRequest(BaseModel):
+    prompt: str
+    max_new_tokens: int = 256
 
-tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+class QueryResponse(BaseModel):
+    response: str
 
-model.eval()
-print("Model ready!")
 
-# ────────────────────────────────────────────────────────────────
-def generate_sql(prompt: str):
-    messages = [{"role": "user", "content": prompt}]
-
-    inputs = tokenizer.apply_chat_template(
-        messages,
-        tokenize=True,
-        add_generation_prompt=True,
+@app.post("/generate", response_model=QueryResponse)
+def generate_answer(request: QueryRequest):
+    inputs = tokenizer(
+        request.prompt,
         return_tensors="pt"
     ).to(model.device)
 
-    with torch.inference_mode():
-        outputs = model.generate(
-            input_ids = inputs,
-            max_new_tokens = MAX_NEW_TOKENS,
-            temperature = TEMPERATURE,
-            do_sample = DO_SAMPLE,
-            use_cache = True,
-            pad_token_id = tokenizer.eos_token_id,
+    with torch.no_grad():
+        output_ids = model.generate(
+            **inputs,
+            max_new_tokens=request.max_new_tokens,
+            do_sample=True,
+            temperature=0.7,
+            top_p=0.9
         )
 
-    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    # Clean output
-    if "<|assistant|>" in response:
-        response = response.split("<|assistant|>", 1)[-1].strip()
-    response = response.split("<|end|>")[0].strip() if "<|end|>" in response else response
-
-    return response
-
-# ────────────────────────────────────────────────────────────────
-# Gradio interface - modern style (no allow_flagging)
-# ────────────────────────────────────────────────────────────────
-
-demo = gr.Interface(
-    fn=generate_sql,
-    inputs=gr.Textbox(
-        label="Ask SQL related question",
-        placeholder="Show me all employees with salary > 50000...",
-        lines=3
-    ),
-    outputs=gr.Textbox(label="Generated SQL / Answer"),
-    title="SQL Chatbot - Optimized",
-    description="Phi-3-mini 4bit + LoRA merged",
-    examples=[
-        ["Find duplicate emails in users table"],
-        ["Top 5 highest paid employees"],
-        ["Count orders per customer last month"]
-    ],
-    # flag button is disabled by default in newer versions → no need for allow_flagging
-)
+    output_text = tokenizer.decode(
+        output_ids[0],
+        skip_special_tokens=True
+    )
 
-if __name__ == "__main__":
-    demo.launch()
+    return {"response": output_text}