File size: 2,351 Bytes
baac0df
 
 
 
 
 
 
6b53423
baac0df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7931a57
baac0df
 
 
 
7931a57
 
baac0df
 
7931a57
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

print("🚀 Loading model...")

# Hugging Face Hub identifier of the merged (LoRA weights folded in)
# fine-tuned checkpoint used for inference below.
model_name = "Abhisek987/llama-3.2-sql-merged"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# fp16 + automatic device placement + low_cpu_mem_usage keep the
# peak host-memory footprint small while the weights are loaded.
_load_options = dict(
    device_map="auto",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
model = AutoModelForCausalLM.from_pretrained(model_name, **_load_options)

print("✅ Model loaded successfully!")

def generate_sql(database, question):
    """Generate SQL query from natural language question"""

    # Alpaca-style instruction prompt; the fine-tuned model emits the SQL
    # after the "### Response:" marker.
    prompt = (
        "### Instruction:\n"
        "You are a SQL expert. Generate a SQL query to answer the given "
        "question for the specified database.\n"
        "\n"
        "### Input:\n"
        f"Database: {database}\n"
        f"Question: {question}\n"
        "\n"
        "### Response:\n"
    )

    # Encode the prompt and move it onto the same device as the model.
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Low-temperature sampling keeps output near-deterministic while
    # avoiding pure greedy decoding; no_grad skips autograd bookkeeping.
    with torch.no_grad():
        generated = model.generate(
            **encoded,
            max_new_tokens=256,
            temperature=0.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Everything after the last "### Response:" marker is the SQL answer.
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    return decoded.split("### Response:")[-1].strip()

# Build the Gradio UI: two text inputs (database name, question) feeding
# generate_sql, one text output showing the generated SQL.
_input_widgets = [
    gr.Textbox(
        label="Database Name",
        placeholder="e.g., employees, sales, customers",
        value="employees",
    ),
    gr.Textbox(
        label="Question",
        placeholder="e.g., Show all employees with salary above 60000",
        lines=3,
    ),
]

_output_widget = gr.Textbox(
    label="Generated SQL Query",
    lines=5,
)

# Clickable example rows shown below the interface.
_example_rows = [
    ["employees", "Show all employees with salary above 60000"],
    ["sales", "Show me the top 5 products by total sales"],
    ["customers", "How many customers are from each country?"],
    ["orders", "Find all orders placed in the last 30 days"],
]

demo = gr.Interface(
    fn=generate_sql,
    inputs=_input_widgets,
    outputs=_output_widget,
    title="🤖 Text-to-SQL Generator",
    description="Fine-tuned Llama 3.2 3B model for SQL query generation using LoRA. Enter a database name and your question in natural language.",
    examples=_example_rows,
)

# Launch without additional parameters for HF Spaces
demo.launch()