saadkhi commited on
Commit
84031c5
Β·
verified Β·
1 Parent(s): 107fcf0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -29
app.py CHANGED
@@ -1,42 +1,88 @@
 
 
1
  import torch
2
- from fastapi import FastAPI
3
- from pydantic import BaseModel
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
5
 
6
- MODEL_ID = "saadkhi/SQL_Chat_finetuned_model"
 
 
7
 
8
- app = FastAPI()
 
 
9
 
10
- # ---- LOAD ONCE ONLY ----
11
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 
 
 
 
 
12
 
13
- model = AutoModelForCausalLM.from_pretrained(
14
- MODEL_ID,
15
- dtype=torch.float16, # use dtype, not torch_dtype
16
- device_map="auto",
17
- low_cpu_mem_usage=True
 
 
 
 
 
18
  )
19
 
20
- model.eval()
 
21
 
 
22
 
23
- class QueryRequest(BaseModel):
24
- prompt: str
25
- max_new_tokens: int = 256
 
 
 
 
 
 
 
26
 
 
 
 
 
 
 
 
 
27
 
28
- @app.post("/generate")
29
- def generate(req: QueryRequest):
30
- inputs = tokenizer(req.prompt, return_tensors="pt").to(model.device)
 
 
 
31
 
32
- with torch.no_grad():
33
- outputs = model.generate(
34
- **inputs,
35
- max_new_tokens=req.max_new_tokens,
36
- do_sample=True,
37
- temperature=0.7,
38
- top_p=0.9
39
- )
 
 
 
 
 
 
 
 
 
 
 
40
 
41
- text = tokenizer.decode(outputs[0], skip_special_tokens=True)
42
- return {"response": text}
 
1
+ # app.py - Optimized for Hugging Face Spaces (Unsloth = 2-4x faster)
2
+
3
  import torch
4
+ import gradio as gr
5
+ from unsloth import FastLanguageModel
 
6
 
7
+ # ────────────────────────────────────────────────────────────────
8
+ BASE_MODEL = "unsloth/Phi-3-mini-4k-instruct-bnb-4bit"
9
+ LORA_PATH = "saadkhi/SQL_Chat_finetuned_model"
10
 
11
+ MAX_NEW_TOKENS = 180
12
+ TEMPERATURE = 0.0 # Greedy = fastest & deterministic
13
+ # ────────────────────────────────────────────────────────────────
14
 
15
+ print("Loading base model with Unsloth (4-bit)...")
16
+ model, tokenizer = FastLanguageModel.from_pretrained(
17
+ model_name = BASE_MODEL,
18
+ max_seq_length = 2048,
19
+ dtype = None, # Auto: bfloat16 on GPU
20
+ load_in_4bit = True, # Already quantized base
21
+ )
22
 
23
+ print("Applying your LoRA adapter...")
24
+ model = FastLanguageModel.get_peft_model(
25
+ model,
26
+ r = 64, # Match your original rank
27
+ target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
28
+ "gate_proj", "up_proj", "down_proj"],
29
+ lora_alpha = 128,
30
+ lora_dropout = 0,
31
+ bias = "none",
32
+ use_gradient_checkpointing = "unsloth",
33
  )
34
 
35
+ # Enable 2x faster inference kernels
36
+ FastLanguageModel.for_inference(model)
37
 
38
+ print("Model ready! (very fast now)")
39
 
40
+ # ────────────────────────────────────────────────────────────────
41
+ def generate_sql(prompt: str):
42
+ messages = [{"role": "user", "content": prompt}]
43
+
44
+ inputs = tokenizer.apply_chat_template(
45
+ messages,
46
+ tokenize=True,
47
+ add_generation_prompt=True,
48
+ return_tensors="pt"
49
+ ).to("cuda" if torch.cuda.is_available() else "cpu")
50
 
51
+ outputs = model.generate(
52
+ input_ids = inputs,
53
+ max_new_tokens = MAX_NEW_TOKENS,
54
+ temperature = TEMPERATURE,
55
+ do_sample = (TEMPERATURE > 0.01),
56
+ use_cache = True,
57
+ pad_token_id = tokenizer.eos_token_id,
58
+ )
59
 
60
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
61
+
62
+ # Extract only assistant response
63
+ if "<|assistant|>" in response:
64
+ response = response.split("<|assistant|>", 1)[-1].strip()
65
+ response = response.split("<|end|>")[0].strip()
66
 
67
+ return response
68
+
69
+ # ────────────────────────────────────────────────────────────────
70
+ demo = gr.Interface(
71
+ fn = generate_sql,
72
+ inputs = gr.Textbox(
73
+ label = "Ask SQL question",
74
+ placeholder = "Delete duplicate rows from users table based on email",
75
+ lines = 3
76
+ ),
77
+ outputs = gr.Textbox(label="Generated SQL"),
78
+ title = "SQL Chatbot - Ultra Fast (Unsloth)",
79
+ description = "Phi-3-mini 4-bit + your LoRA",
80
+ examples = [
81
+ ["Find duplicate emails in users table"],
82
+ ["Top 5 highest paid employees"],
83
+ ["Count orders per customer last month"]
84
+ ]
85
+ )
86
 
87
+ if __name__ == "__main__":
88
+ demo.launch()