File size: 3,392 Bytes
972aab5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import streamlit as st
import sqlite3
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer
import database
import json

# Initialize database
# NOTE(review): init_database() lives in the project-local `database` module,
# which is not visible from this file — presumably it creates/populates the
# SQLite file that execute_query() reads ('data.db'); confirm in database.py.
database.init_database()

# Get schema information
# `schema_info` is a module-level value consumed by create_schema_prompt(),
# which iterates it as a mapping of table name -> dict and reads the
# 'columns', 'types' and 'sample_values' entries of each table dict.
schema_info = database.get_schema_info()

# Initialize the model and tokenizer
@st.cache_resource
def load_model():
    """Load the CodeLlama-7b causal language model together with its tokenizer.

    The function is wrapped in ``st.cache_resource`` so that Streamlit can
    reuse the loaded objects instead of re-loading them on every call.

    Returns:
        tuple: ``(model, tokenizer)`` — note the order: model first.
    """
    # Both objects must come from the same pretrained checkpoint.
    checkpoint = "codellama/CodeLlama-7b-hf"
    llm_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    llm = AutoModelForCausalLM.from_pretrained(checkpoint)
    return llm, llm_tokenizer

def create_schema_prompt():
    """Render the module-level ``schema_info`` as a plain-text schema description.

    For every table the output lists each column with its declared type and
    up to three example values. The same text is embedded in the LLM prompt
    and shown to the user in the schema expander.

    Returns:
        str: The text, always starting with ``"Database Schema:\\n"``.
    """
    pieces = ["Database Schema:\n"]
    for table_name, table_meta in schema_info.items():
        pieces.append(f"\nTable: {table_name}\n")
        pieces.append("Columns:\n")
        for column, column_type in zip(table_meta['columns'], table_meta['types']):
            # Only the first three sample values are shown to keep the prompt short.
            examples = ', '.join(
                str(value) for value in table_meta['sample_values'][column][:3]
            )
            pieces.append(f"- {column} ({column_type}), Example values: {examples}\n")
    return "".join(pieces)

def generate_sql_query(question, max_new_tokens=500):
    """Ask the language model to translate a natural-language question into SQL.

    The prompt consists of the schema description produced by
    ``create_schema_prompt()``, the user's question and a trailing
    ``SQL Query:`` cue for the model to complete.

    Args:
        question: The user's question in natural language.
        max_new_tokens: Upper bound on the number of tokens the model may
            generate *in addition to* the prompt.

    Returns:
        str: The model's completion with surrounding whitespace removed.
        The text is not validated as SQL.
    """
    model, tokenizer = load_model()

    # Create detailed prompt with schema information
    schema_prompt = create_schema_prompt()
    prompt = f"""Given the following database schema and question, generate a SQL query that answers the question.

{schema_prompt}

Question: {question}

Write only the SQL query without any additional text or explanation. Make sure to:
1. Use the correct table and column names as shown in the schema
2. Handle joins appropriately if multiple tables are needed
3. Use appropriate SQL functions based on the question context

SQL Query:"""

    # Generate SQL query
    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
    prompt_token_count = inputs["input_ids"].shape[-1]
    outputs = model.generate(
        **inputs,
        # `max_length` would count the prompt tokens as well; the prompt may be
        # up to 1024 tokens long, which would leave no budget for the answer.
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        temperature=0.7,
        top_p=0.95,
        do_sample=True
    )

    # The returned sequence starts with the prompt tokens. Decode only what
    # follows them, instead of searching the decoded text for the
    # "SQL Query:" marker: the marker can be lost when the prompt is
    # truncated, or appear a second time inside the user's question.
    completion_ids = outputs[0][prompt_token_count:]
    sql_query = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return sql_query.strip()

def execute_query(query):
    """Run a SQL query against ``data.db`` and return its result set.

    The query text comes from a language model and is therefore untrusted.
    The database is opened read-only so that a generated ``DROP``,
    ``DELETE``, ``UPDATE`` etc. fails instead of modifying the data.

    Args:
        query: SQL text to execute.

    Returns:
        tuple: ``(DataFrame, None)`` on success, or ``(None, message)`` when
        the database cannot be opened or the query fails. No exception is
        propagated to the caller.
    """
    conn = None
    try:
        # Connecting happens inside the try block so that a connection failure
        # is reported through the same (None, message) channel as a query
        # failure. mode=ro also prevents SQLite from creating an empty
        # database file when data.db does not exist.
        conn = sqlite3.connect('file:data.db?mode=ro', uri=True)
        result = pd.read_sql_query(query, conn)
        return result, None
    except Exception as e:
        return None, str(e)
    finally:
        if conn is not None:
            conn.close()

# Streamlit UI
st.title("Intelligent Text to SQL Query Assistant")
st.write("Ask questions about your data in natural language!")

# Display schema information in expandable section
with st.expander("View Database Schema"):
    st.code(create_schema_prompt(), language="text")

# User input
user_question = st.text_area("Enter your question:", height=100)

if st.button("Generate and Execute Query"):
    # A whitespace-only question is treated the same as an empty one, so the
    # model is never invoked without an actual question.
    question = user_question.strip()
    if not question:
        st.warning("Please enter a question before generating a query.")
    else:
        # Each spinner wraps only the step it describes; the spinners are not
        # nested, so "Generating..." is gone once the query is available.
        with st.spinner("Generating SQL query..."):
            sql_query = generate_sql_query(question)

        # Display the generated query
        st.subheader("Generated SQL Query:")
        st.code(sql_query, language="sql")

        # Execute the query
        with st.spinner("Executing query..."):
            results, error = execute_query(sql_query)

        # execute_query() signals success with error=None; compare against
        # None so that an empty error message still counts as a failure.
        if error is not None:
            st.error(f"Error executing query: {error}")
        else:
            st.subheader("Query Results:")
            st.dataframe(results)