saadkhi committed on
Commit
52ae0ac
·
verified Β·
1 Parent(s): 7cbf352

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -67
app.py CHANGED
@@ -1,94 +1,71 @@
1
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
import torch
import os

# Directory where accelerate can spill layers that do not fit in GPU memory.
OFFLOAD_DIR = "offload"
os.makedirs(OFFLOAD_DIR, exist_ok=True)

# 4-bit NF4 quantization, double-quantized, computing in bfloat16.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Hub identifiers: the pre-quantized base checkpoint and the fine-tuned LoRA adapter.
base_model_name = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
lora_model_name = "saadkhi/SQL_Chat_finetuned_model"

print("Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=quant_config,
    device_map="auto",
    trust_remote_code=True,
    offload_folder=OFFLOAD_DIR,  # required so device_map="auto" can offload to disk
)

print("Loading LoRA adapter...")
model = PeftModel.from_pretrained(
    base_model,
    lora_model_name,
    offload_folder=OFFLOAD_DIR,  # adapter weights follow the same offload plan
)

tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)

# Inference only: disable dropout etc.
model.eval()
print("Model loaded successfully!")
43
 
44
def chat(message, history):
    """Run one chat turn against the fine-tuned Phi-3 model.

    Rebuilds the conversation in chat-template form, generates a reply,
    appends the (message, reply) pair to *history*, and returns the
    updated history plus an empty string to clear the input textbox.
    """
    # Convert (user, assistant) tuples into chat-template messages.
    convo = []
    for user_turn, bot_turn in history:
        convo.append({"role": "user", "content": user_turn})
        if bot_turn:  # skip empty/pending assistant slots
            convo.append({"role": "assistant", "content": bot_turn})
    convo.append({"role": "user", "content": message})

    prompt_ids = tokenizer.apply_chat_template(
        convo,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    generated = model.generate(
        prompt_ids,
        max_new_tokens=256,
        temperature=0.7,
        do_sample=True,
        top_p=0.9,
        repetition_penalty=1.1,
        use_cache=True,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Decode only the tokens produced after the prompt.
    new_tokens = generated[0][prompt_ids.shape[-1]:]
    reply = tokenizer.decode(new_tokens, skip_special_tokens=True)

    history.append((message, reply))
    return history, ""
76
 
77
# ── Gradio UI ────────────────────────────────────────────────────────────
with gr.Blocks(title="SQL Chatbot", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# SQL Chat Assistant")
    gr.Markdown("Fine-tuned Phi-3 Mini (4-bit) for SQL queries. Responses ~3–10s on GPU.")

    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(
        label="Your Question",
        placeholder="e.g., delete duplicate rows from users table based on email",
        lines=2
    )
    clear = gr.Button("Clear")

    # chat() returns (history, "") which maps onto [chatbot, msg].
    msg.submit(chat, [msg, chatbot], [chatbot, msg])
    # BUG FIX: the lambda returns two values but only `chatbot` was listed as
    # an output, which makes Gradio raise "returned too many output values".
    # Clear both the chat history and the textbox.
    clear.click(lambda: ([], ""), None, [chatbot, msg])

demo.queue(max_size=30)
demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import torch
from unsloth import FastLanguageModel

# ── Global model (loaded once at startup) ───────────────────────────────
print("Loading model...")

model, tokenizer = FastLanguageModel.from_pretrained(
    "unsloth/Phi-3-mini-4k-instruct-bnb-4bit",  # pre-quantized base: fast to load
    max_seq_length=2048,
    dtype=None,          # auto-select bf16/fp16 for the current hardware
    load_in_4bit=True,
)

# Load the fine-tuned LoRA adapter.
# BUG FIX: `load_adapter` mutates the model in place and does NOT return the
# model (PEFT returns None / the adapter name), so the original
# `FastLanguageModel.for_inference(model.load_adapter(...))` handed a
# non-model to for_inference. Attach the adapter first, then switch the
# model itself into fast-inference mode.
model.load_adapter("saadkhi/SQL_Chat_finetuned_model")
model = FastLanguageModel.for_inference(model)

print("Model loaded successfully!")
21
 
22
# ── Chat function ───────────────────────────────────────────────────────
def generate_response(message, history):
    """Generate one assistant reply for the Gradio ChatInterface.

    Args:
        message: The new user message.
        history: Prior (user, assistant) turns as supplied by gr.ChatInterface
            (tuple format).

    Returns:
        The model's reply as a string, with the prompt portion stripped.
    """
    # Rebuild the multi-turn conversation in chat-template form.
    messages = []
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg is not None:  # guard pending turns (e.g. retry/undo)
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # apply_chat_template inserts the Phi-3 special tokens for us.
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to("cuda" if torch.cuda.is_available() else "cpu")

    # Greedy decoding: fastest and deterministic.
    # BUG FIX: the original also passed temperature=0.0, which is ignored
    # when do_sample=False and only triggers a transformers warning.
    outputs = model.generate(
        input_ids=inputs,
        max_new_tokens=180,
        do_sample=False,
        use_cache=True,
    )

    # Decode only the newly generated tokens. This is exact, unlike decoding
    # the whole sequence and string-splitting on "<|assistant|>".
    response = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
    return response.strip()
 
 
 
 
 
 
 
55
 
 
 
56
 
57
# ── Gradio UI ───────────────────────────────────────────────────────────
# BUG FIX: the description contained the mojibake sequence "β€’" (a bullet
# "•" whose UTF-8 bytes were decoded with the wrong codec); restored to "•".
demo = gr.ChatInterface(
    fn=generate_response,
    title="SQL Chat Assistant (Fast Version)",
    description="Ask SQL related questions • Powered by Phi-3-mini + your fine-tune",
    examples=[
        "Write a query to find duplicate emails in users table",
        "How to delete rows with NULL values in column price?",
        "Select top 10 most expensive products",
    ],
    cache_examples=False,  # examples hit the model; don't pre-run them at startup
)

if __name__ == "__main__":
    demo.launch()