saadkhi committed on
Commit
1344c31
·
verified ·
1 Parent(s): c6ed9f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -27
app.py CHANGED
@@ -1,30 +1,96 @@
 
1
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# Base 4-bit checkpoint and the fine-tuned LoRA adapter layered on top of it.
device = "cuda" if torch.cuda.is_available() else "cpu"
base_model = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
finetuned_model = "saadkhi/SQL_Chat_finetuned_model"

tokenizer = AutoTokenizer.from_pretrained(base_model)
bnb = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb,
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    device_map="auto",
)
# BUG FIX: do not call .to(device) on the PEFT-wrapped model —
# device_map="auto" has already placed the 4-bit weights, and moving a
# bitsandbytes-quantized model is unsupported (raises in recent
# transformers releases).
model = PeftModel.from_pretrained(model, finetuned_model)
model.eval()


def chat(prompt):
    """Generate a completion for *prompt* with the fine-tuned model."""
    # Send inputs to wherever device_map placed the model.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.inference_mode():
        output = model.generate(
            **inputs,
            max_new_tokens=60,
            # Greedy decoding; a temperature value would be ignored here
            # (and triggers a transformers warning), so it is omitted.
            do_sample=False,
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)


iface = gr.Interface(fn=chat, inputs="text", outputs="text", title="SQL Chatbot")
iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# app.py
import torch
import gradio as gr
from unsloth import FastLanguageModel

# ────────────────────────────────────────────────────────────────
# Configuration - change here if needed
# ────────────────────────────────────────────────────────────────
MAX_NEW_TOKENS = 96
TEMPERATURE = 0.0  # 0.0 = greedy decoding = fastest
BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"

# ────────────────────────────────────────────────────────────────
# BUG FIX: the previous version loaded BASE_MODEL and then called
# FastLanguageModel.get_peft_model(...), which attaches *freshly
# initialized* (random) LoRA adapters — LORA_PATH was defined but never
# used, so the fine-tuned weights were never served. Loading the adapter
# repo directly makes unsloth resolve the base model and apply the
# trained adapter on top of it.
print("Loading fine-tuned model with Unsloth...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=LORA_PATH,  # adapter repo; base checkpoint resolved automatically
    max_seq_length=2048,
    dtype=None,  # auto-detect (bf16 on GPU)
    load_in_4bit=True,
)

print("Preparing for inference...")
model = FastLanguageModel.for_inference(model)  # important! activates 2x faster kernels

# Optional - compile can give additional 20-60% speedup (PyTorch 2.0+).
# Compare the parsed major version: the old string comparison
# (torch.__version__ >= "2.0") is lexicographic and would misorder
# versions such as "10.0".
if torch.cuda.is_available() and int(torch.__version__.split(".")[0]) >= 2:
    print("Compiling model...")
    model = torch.compile(model, mode="reduce-overhead")

print("Model ready!")
# ────────────────────────────────────────────────────────────────
def generate_sql(prompt: str) -> str:
    """Run the chat model on *prompt* and return only the assistant's reply.

    Applies the tokenizer's chat template, generates with greedy decoding
    when TEMPERATURE is effectively zero, and decodes only the newly
    generated tokens so the echoed prompt never leaks into the response.
    """
    messages = [{"role": "user", "content": prompt}]

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to("cuda" if torch.cuda.is_available() else "cpu")

    do_sample = TEMPERATURE > 0.01
    gen_kwargs = dict(
        input_ids=inputs,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=do_sample,
        use_cache=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Only pass temperature when actually sampling: transformers warns
    # (and newer releases reject) temperature=0.0 with greedy decoding.
    if do_sample:
        gen_kwargs["temperature"] = TEMPERATURE

    outputs = model.generate(**gen_kwargs)

    # BUG FIX: the old code decoded the *entire* sequence with
    # skip_special_tokens=True and then split on "<|assistant|>" /
    # "<|end|>" — but those markers are special tokens that decoding had
    # already stripped, so the splits never matched and the user's prompt
    # was echoed back in the answer. Slicing off the prompt tokens and
    # decoding only the generated tail is robust regardless of template.
    new_tokens = outputs[0][inputs.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# ────────────────────────────────────────────────────────────────
# Gradio UI wiring: a single text-in / text-out interface around
# generate_sql, with a few canned example prompts.
question_box = gr.Textbox(
    label="Ask SQL related question",
    placeholder="Show me all employees with salary > 50000...",
    lines=3,
)
answer_box = gr.Textbox(label="Generated SQL / Answer")
example_prompts = [
    ["Find all duplicate emails in users table"],
    ["Get top 5 highest paid employees"],
    ["How many orders per customer last month?"],
]

demo = gr.Interface(
    fn=generate_sql,
    inputs=question_box,
    outputs=answer_box,
    title="SQL Chat Assistant (Phi-3-mini fine-tuned)",
    description="Fast version using Unsloth",
    examples=example_prompts,
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()