saadkhi committed
Commit c38fb83 · verified · 1 Parent(s): c2d5c36

Update app.py

Files changed (1)
  1. app.py +9 -11
app.py CHANGED
@@ -52,7 +52,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 from peft import PeftModel
 import torch
 
-# Quantization config for fast 4-bit loading
+# Best 4-bit config for speed + low memory
 quant_config = BitsAndBytesConfig(
     load_in_4bit=True,
     bnb_4bit_quant_type="nf4",
@@ -60,17 +60,17 @@ quant_config = BitsAndBytesConfig(
     bnb_4bit_use_double_quant=True,
 )
 
-# Load base model + your LoRA once at startup
+# Load base + your LoRA once
 base_model_name = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
 lora_model_name = "saadkhi/SQL_Chat_finetuned_model"
 
-print("Loading model (20–40 seconds first time)...")
+print("Loading model (20–40s first time)...")
 base_model = AutoModelForCausalLM.from_pretrained(
     base_model_name,
     quantization_config=quant_config,
     device_map="auto",
     trust_remote_code=True,
-    attn_implementation="flash_attention_2",  # Fastest on T4/A10G
+    # Removed flash_attention_2 (avoids install issues)
 )
 
 model = PeftModel.from_pretrained(base_model, lora_model_name)
@@ -80,7 +80,7 @@ model.eval()
 print("Model ready!")
 
 def chat(message, history):
-    # Build full conversation history in Phi-3 format
+    # Full conversation history
     messages = []
     for user, assistant in history:
         messages.append({"role": "user", "content": user})
@@ -88,7 +88,6 @@ def chat(message, history):
         messages.append({"role": "assistant", "content": assistant})
     messages.append({"role": "user", "content": message})
 
-    # Tokenize with chat template
     inputs = tokenizer.apply_chat_template(
         messages,
         tokenize=True,
@@ -96,7 +95,7 @@ def chat(message, history):
         return_tensors="pt"
     ).to(model.device)
 
-    # Generate with optimal settings
+    # Optimized generation
     outputs = model.generate(
         inputs,
         max_new_tokens=256,
@@ -104,20 +103,19 @@ def chat(message, history):
         do_sample=True,
         top_p=0.9,
         repetition_penalty=1.1,
-        use_cache=True,  # KV caching = much faster
+        use_cache=True,  # KV cache = faster sequential tokens
         eos_token_id=tokenizer.eos_token_id,
     )
 
-    # Decode only the new response
     response = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
 
     history.append((message, response))
     return history, ""
 
-# Gradio interface
+# UI
 with gr.Blocks(title="SQL Chatbot", theme=gr.themes.Soft()) as demo:
     gr.Markdown("# SQL Chat Assistant")
-    gr.Markdown("Fine-tuned Phi-3 Mini for SQL queries. Responses in 26 seconds on GPU.")
+    gr.Markdown("Fine-tuned Phi-3 Mini for SQL. Fast responses (38s on GPU).")
 
     chatbot = gr.Chatbot(height=500)
     msg = gr.Textbox(label="Your Question", placeholder="e.g., delete duplicate rows from users table based on email", lines=2)
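The only change that affects behavior is removing attn_implementation="flash_attention_2" from from_pretrained: the Space now starts even where the flash-attn package is not installed, at the cost of slower attention on GPUs that do support it. A possible middle ground, shown here only as a hedged sketch (the model_kwargs dict and the import check are illustrative and not part of this commit), is to request flash attention only when flash_attn is importable:

import importlib.util

from transformers import AutoModelForCausalLM

# Reuses quant_config and base_model_name defined earlier in app.py.
model_kwargs = dict(
    quantization_config=quant_config,
    device_map="auto",
    trust_remote_code=True,
)
if importlib.util.find_spec("flash_attn") is not None:
    # Illustrative check only: flash_attention_2 also needs an Ampere-or-newer
    # GPU, so an installed package is a heuristic, not a guarantee.
    model_kwargs["attn_implementation"] = "flash_attention_2"

base_model = AutoModelForCausalLM.from_pretrained(base_model_name, **model_kwargs)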
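The hunks end at the msg textbox, so the event wiring and launch call live in the untouched part of app.py and are not shown here. For reference, a hypothetical wiring that matches the chat(message, history) -> (history, "") signature used above could look like this (component names follow the diff; the submit hookup itself is an assumption, not the author's code):

with gr.Blocks(title="SQL Chatbot", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# SQL Chat Assistant")
    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Your Question", lines=2)

    # chat returns (updated_history, ""), so outputs=[chatbot, msg] both
    # refreshes the chat window and clears the textbox after each question.
    msg.submit(chat, inputs=[msg, chatbot], outputs=[chatbot, msg])

demo.launch()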