File size: 2,656 Bytes
d25573b
 
 
 
 
3a92db0
d25573b
3a92db0
 
d25573b
 
 
3a92db0
d25573b
 
 
 
 
 
3a92db0
d25573b
 
 
3a92db0
d25573b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a92db0
d25573b
 
 
 
3a92db0
d25573b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a92db0
d25573b
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel


# Repository IDs handed to ``from_pretrained``: the chat-tuned base checkpoint
# and the LoRA adapter that gets applied on top of it.
BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
LORA_ADAPTER_ID = "adamabuhamdan/tinyllama-sql-lora"

# The tokenizer comes from the base checkpoint, not from the adapter repository.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

# Base weights are loaded in full float32 precision and placed on the CPU.
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    device_map="cpu",
    torch_dtype=torch.float32,
)

# Wrap the base model with the LoRA adapter, then switch the combined model to
# evaluation mode (disables dropout) because it is only used for inference.
model = PeftModel.from_pretrained(base_model, LORA_ADAPTER_ID)
model.eval()


def generate_sql(schema, question, max_new_tokens=150):
    """Generate a SQL query answering *question* against *schema*.

    The schema and question are placed in a chat-formatted prompt (via the
    tokenizer's chat template) after a system prompt that asks for the bare
    SQL query only. The LoRA-adapted model then decodes greedily
    (``do_sample=False``).

    Args:
        schema: Table definition(s) as text, e.g. a ``CREATE TABLE`` statement.
            ``None`` is treated as an empty schema.
        question: Natural-language question to translate into SQL. ``None`` or
            a whitespace-only string short-circuits to an empty result.
        max_new_tokens: Upper bound on the number of tokens to generate.
            Defaults to 150, the value that used to be hard-coded here.

    Returns:
        The generated continuation with special tokens removed and surrounding
        whitespace stripped, or ``""`` when there is no question to answer.
    """
    # Treat None like an empty string and trim stray whitespace before it is
    # embedded in the prompt.
    schema = (schema or "").strip()
    question = (question or "").strip()
    if not question:
        # Nothing to translate: avoid an expensive generation pass.
        return ""

    system_prompt = "You are a SQL assistant. Given a table schema and a question, reply with ONLY the SQL query, nothing else."
    user_prompt = f"Schema:\n{schema}\n\nQuestion: {question}"

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Render the conversation as text ending with the assistant turn marker so
    # the model continues as the assistant.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt")

    with torch.no_grad():  # inference only; no autograd graph needed
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding
        )

    # The returned sequence starts with the prompt tokens; decode only the
    # newly generated part.
    input_length = inputs.input_ids.shape[1]
    prediction = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip()
    return prediction


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Header: title plus a one-line usage hint. The hint is in Arabic and reads
    # roughly: "Enter the table structure and your question in natural language
    # to get instant SQL code."
    gr.Markdown("# 🤖 SQL Assistant (TinyLlama + LoRA)")
    gr.Markdown("قم بإدخال هيكل الجدول وسؤالك باللغة الطبيعية لتحصل على كود SQL فوري.")

    with gr.Row():
        # Left column: the two text inputs and the button that triggers generation.
        with gr.Column():
            schema_input = gr.Textbox(
                lines=5,
                label="Database Schema",
                placeholder="CREATE TABLE users (id INT, name TEXT...);",
            )
            question_input = gr.Textbox(
                lines=2,
                label="Your Question",
                placeholder="List all users older than 25.",
            )
            submit_btn = gr.Button("Generate SQL", variant="primary")

        # Right column: the generated query, displayed with SQL highlighting.
        with gr.Column():
            sql_output = gr.Code(language="sql", label="Generated SQL Query")

    # Sample (schema, question) pairs bound to the two input boxes.
    gr.Examples(
        inputs=[schema_input, question_input],
        examples=[
            [
                "CREATE TABLE employees (id INT, name TEXT, salary INT);",
                "Show names of employees earning more than 5000.",
            ],
            [
                "CREATE TABLE movies (title TEXT, year INT, rating FLOAT);",
                "Find the highest rated movie from 2022.",
            ],
        ],
    )

    # On click, feed both inputs to generate_sql and show its result on the right.
    submit_btn.click(generate_sql, inputs=[schema_input, question_input], outputs=sql_output)

# Start the web server only when this file is executed as a script, so importing
# the module (e.g. to reuse generate_sql) does not block on a running server.
if __name__ == "__main__":
    demo.launch()