saadkhi commited on
Commit
43c048b
·
verified ·
1 Parent(s): 5c843a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -28
app.py CHANGED
@@ -48,36 +48,47 @@
48
 
49
 
50
  import gradio as gr
51
- from unsloth import FastLanguageModel
 
52
  import torch
53
 
54
- # Load model once at startup — Unsloth makes it 2.5x faster
55
- model, tokenizer = FastLanguageModel.from_pretrained(
56
- model_name="unsloth/Phi-3-mini-4k-instruct-bnb-4bit",
57
- max_seq_length=4096,
58
- dtype=None, # Auto detect (bfloat16 if supported)
59
  load_in_4bit=True,
 
 
 
60
  )
61
 
62
- # Load your fine-tuned LoRA adapter
63
- model = FastLanguageModel.get_peft_model(
64
- model,
65
- "saadkhi/SQL_Chat_finetuned_model", # Your HF repo
 
 
 
 
 
 
 
66
  )
67
 
68
- # Enable fast inference mode (critical for speed!)
69
- FastLanguageModel.for_inference(model)
 
 
 
70
 
71
  def chat(message, history):
72
- # Build proper Phi-3 chat format
73
  messages = []
74
- for user_msg, bot_msg in history:
75
- messages.append({"role": "user", "content": user_msg})
76
- if bot_msg:
77
- messages.append({"role": "assistant", "content": bot_msg})
78
  messages.append({"role": "user", "content": message})
79
 
80
- # Apply chat template and tokenize
81
  inputs = tokenizer.apply_chat_template(
82
  messages,
83
  tokenize=True,
@@ -85,33 +96,35 @@ def chat(message, history):
85
  return_tensors="pt"
86
  ).to(model.device)
87
 
88
- # Generate fast
89
- output = model.generate(
90
- input_ids=inputs,
91
  max_new_tokens=256,
92
  temperature=0.7,
93
  do_sample=True,
94
  top_p=0.9,
95
- use_cache=True,
96
  repetition_penalty=1.1,
 
 
97
  )
98
 
99
- # Decode only the new part
100
- response = tokenizer.decode(output[0][inputs.shape[-1]:], skip_special_tokens=True)
 
101
  history.append((message, response))
102
  return history, ""
103
 
104
- # Clean Gradio Chat Interface
105
  with gr.Blocks(title="SQL Chatbot", theme=gr.themes.Soft()) as demo:
106
  gr.Markdown("# SQL Chat Assistant")
107
- gr.Markdown("Ask any SQL-related question. Fast responses powered by fine-tuned Phi-3 Mini.")
108
 
109
  chatbot = gr.Chatbot(height=500)
110
- msg = gr.Textbox(label="Your Message", placeholder="e.g., delete duplicate rows from users table", lines=2)
111
  clear = gr.Button("Clear")
112
 
113
  msg.submit(chat, [msg, chatbot], [chatbot, msg])
114
  clear.click(lambda: ([], ""), None, chatbot)
115
 
116
- demo.queue(max_size=20) # Handle multiple users smoothly
117
  demo.launch()
 
48
 
49
 
50
  import gradio as gr
51
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
52
+ from peft import PeftModel
53
  import torch
54
 
55
+ # Quantization config for fast 4-bit loading
56
+ quant_config = BitsAndBytesConfig(
 
 
 
57
  load_in_4bit=True,
58
+ bnb_4bit_quant_type="nf4",
59
+ bnb_4bit_compute_dtype=torch.bfloat16,
60
+ bnb_4bit_use_double_quant=True,
61
  )
62
 
63
+ # Load base model + your LoRA once at startup
64
+ base_model_name = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
65
+ lora_model_name = "saadkhi/SQL_Chat_finetuned_model"
66
+
67
+ print("Loading model (20–40 seconds first time)...")
68
+ base_model = AutoModelForCausalLM.from_pretrained(
69
+ base_model_name,
70
+ quantization_config=quant_config,
71
+ device_map="auto",
72
+ trust_remote_code=True,
73
+ attn_implementation="flash_attention_2", # Fastest on T4/A10G
74
  )
75
 
76
+ model = PeftModel.from_pretrained(base_model, lora_model_name)
77
+ tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
78
+
79
+ model.eval()
80
+ print("Model ready!")
81
 
82
  def chat(message, history):
83
+ # Build full conversation history in Phi-3 format
84
  messages = []
85
+ for user, assistant in history:
86
+ messages.append({"role": "user", "content": user})
87
+ if assistant:
88
+ messages.append({"role": "assistant", "content": assistant})
89
  messages.append({"role": "user", "content": message})
90
 
91
+ # Tokenize with chat template
92
  inputs = tokenizer.apply_chat_template(
93
  messages,
94
  tokenize=True,
 
96
  return_tensors="pt"
97
  ).to(model.device)
98
 
99
+ # Generate with optimal settings
100
+ outputs = model.generate(
101
+ inputs,
102
  max_new_tokens=256,
103
  temperature=0.7,
104
  do_sample=True,
105
  top_p=0.9,
 
106
  repetition_penalty=1.1,
107
+ use_cache=True, # KV caching = much faster
108
+ eos_token_id=tokenizer.eos_token_id,
109
  )
110
 
111
+ # Decode only the new response
112
+ response = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
113
+
114
  history.append((message, response))
115
  return history, ""
116
 
117
+ # Gradio interface
118
  with gr.Blocks(title="SQL Chatbot", theme=gr.themes.Soft()) as demo:
119
  gr.Markdown("# SQL Chat Assistant")
120
+ gr.Markdown("Fine-tuned Phi-3 Mini for SQL queries. Responses in 2–6 seconds on GPU.")
121
 
122
  chatbot = gr.Chatbot(height=500)
123
+ msg = gr.Textbox(label="Your Question", placeholder="e.g., delete duplicate rows from users table based on email", lines=2)
124
  clear = gr.Button("Clear")
125
 
126
  msg.submit(chat, [msg, chatbot], [chatbot, msg])
127
  clear.click(lambda: ([], ""), None, chatbot)
128
 
129
+ demo.queue(max_size=30)
130
  demo.launch()