saadkhi committed on
Commit
0ddc005
·
verified ·
1 Parent(s): a663164

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -45
app.py CHANGED
@@ -47,62 +47,71 @@
47
 
48
 
49
 
50
# --- One-time model setup (module import time) ---
import torch
import gradio as gr
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Prefer the GPU when one is visible; fall back to CPU otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Base 4-bit Phi-3 checkpoint and the LoRA adapter fine-tuned for SQL chat.
base_model = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
finetuned_model = "saadkhi/SQL_Chat_finetuned_model"

tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

# Load the base model quantized to 4 bits, then attach the fine-tuned adapter.
bnb_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
    device_map="auto",
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(model, finetuned_model)
model.eval()  # inference only — disables dropout etc.
74
 
75
def chat(user_prompt, history=None):
    """Generate a SQL-assistant reply for a single user prompt.

    FIX: ``gr.ChatInterface`` invokes its fn as ``fn(message, history)``;
    the original one-parameter signature raised ``TypeError`` on every
    submit. ``history`` is accepted (and ignored — each turn is answered
    statelessly) to satisfy that contract while remaining callable with
    just a prompt.

    Args:
        user_prompt: The user's question (e.g. a SQL task description).
        history: Prior chat turns supplied by Gradio; unused.

    Returns:
        The assistant's reply text, stripped of chat-template markers.
    """
    # Proper Phi-3 chat format: a single user turn.
    messages = [{"role": "user", "content": user_prompt}]

    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(device)

    with torch.inference_mode():
        outputs = model.generate(
            inputs,
            max_new_tokens=256,  # enough room for a full SQL statement
            temperature=0.7,
            do_sample=True,  # sampling for varied phrasings
            top_p=0.9,
            repetition_penalty=1.1,
        )

    # Keep special tokens so the <|assistant|> ... <|end|> markers can be
    # used to slice out just the assistant's reply from the full decode.
    response = tokenizer.decode(outputs[0], skip_special_tokens=False)
    response = response.split("<|assistant|>")[-1].split("<|end|>")[0].strip()

    return response
101
 
102
# Minimal chat UI; gr.ChatInterface wires up the textbox and history for us.
iface = gr.ChatInterface(
    fn=chat,
    title="Fast SQL Chatbot",
    description="Ask SQL questions (e.g., 'delete duplicate rows based on email')",
)

iface.launch()
 
 
47
 
48
 
49
 
 
50
import gradio as gr
import torch
from unsloth import FastLanguageModel

# Load the fine-tuned model once at startup — Unsloth makes inference faster.
#
# FIX: the previous code called
#     FastLanguageModel.get_peft_model(model, "saadkhi/SQL_Chat_finetuned_model")
# but get_peft_model *creates fresh, untrained* LoRA modules — its second
# positional parameter is the LoRA rank ``r`` (an int), not a repo id — so the
# fine-tuned weights were never loaded. Loading the adapter repo directly lets
# Unsloth resolve the recorded base model
# (unsloth/Phi-3-mini-4k-instruct-bnb-4bit) and attach the trained LoRA weights.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="saadkhi/SQL_Chat_finetuned_model",  # fine-tuned LoRA repo
    max_seq_length=4096,
    dtype=None,          # auto-detect (bfloat16 if supported)
    load_in_4bit=True,
)

# Enable fast inference mode (critical for speed!)
FastLanguageModel.for_inference(model)
71
def chat(message, history):
    """Answer one chat turn.

    Returns the updated tuple-format history plus an empty string, which
    the UI wiring uses to clear the input textbox after each send.
    """
    # Rebuild the whole conversation in Phi-3 chat-message format.
    convo = []
    for prev_user, prev_bot in history:
        convo.append({"role": "user", "content": prev_user})
        if prev_bot:
            convo.append({"role": "assistant", "content": prev_bot})
    convo.append({"role": "user", "content": message})

    # Tokenize via the model's chat template, ending with the generation prompt.
    prompt_ids = tokenizer.apply_chat_template(
        convo,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    generated = model.generate(
        input_ids=prompt_ids,
        max_new_tokens=256,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        use_cache=True,  # KV-cache on for fast autoregressive decoding
        repetition_penalty=1.1,
    )

    # Decode only the newly generated tokens (everything after the prompt).
    new_tokens = generated[0][prompt_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    history.append((message, response))
    return history, ""
103
+
104
# Clean Gradio Chat Interface
with gr.Blocks(title="SQL Chatbot", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# SQL Chat Assistant")
    gr.Markdown("Ask any SQL-related question. Fast responses powered by fine-tuned Phi-3 Mini.")

    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(label="Your Message", placeholder="e.g., delete duplicate rows from users table", lines=2)
    clear = gr.Button("Clear")

    # chat() returns (updated_history, "") — the "" clears the textbox.
    msg.submit(chat, [msg, chatbot], [chatbot, msg])
    # FIX: the clear handler returns TWO values (empty history, empty textbox),
    # so it must target BOTH components. Previously only `chatbot` was listed,
    # which assigned the whole ([], "") tuple to the chatbot and never cleared
    # the textbox.
    clear.click(lambda: ([], ""), None, [chatbot, msg])

demo.queue(max_size=20)  # Handle multiple users smoothly
demo.launch()