souvik16011991roy committed on
Commit
972aab5
·
verified ·
1 Parent(s): fbf23ae

Upload 4 files

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. app.py +104 -0
  3. data.db +3 -0
  4. database.py +88 -0
  5. requirements.txt +7 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data.db filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import sqlite3
3
+ import pandas as pd
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ import database
6
+ import json
7
+
8
+ # Initialize database
9
+ database.init_database()
10
+
11
+ # Get schema information
12
+ schema_info = database.get_schema_info()
13
+
14
+ # Initialize the model and tokenizer
15
+ @st.cache_resource
16
+ def load_model():
17
+ tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")
18
+ model = AutoModelForCausalLM.from_pretrained("codellama/CodeLlama-7b-hf")
19
+ return model, tokenizer
20
+
21
+ def create_schema_prompt():
22
+ prompt = "Database Schema:\n"
23
+ for table, info in schema_info.items():
24
+ prompt += f"\nTable: {table}\n"
25
+ prompt += "Columns:\n"
26
+ for col, type_ in zip(info['columns'], info['types']):
27
+ sample_values = info['sample_values'][col][:3] # Take first 3 sample values
28
+ prompt += f"- {col} ({type_}), Example values: {', '.join(map(str, sample_values))}\n"
29
+ return prompt
30
+
31
+ def generate_sql_query(question):
32
+ model, tokenizer = load_model()
33
+
34
+ # Create detailed prompt with schema information
35
+ schema_prompt = create_schema_prompt()
36
+ prompt = f"""Given the following database schema and question, generate a SQL query that answers the question.
37
+
38
+ {schema_prompt}
39
+
40
+ Question: {question}
41
+
42
+ Write only the SQL query without any additional text or explanation. Make sure to:
43
+ 1. Use the correct table and column names as shown in the schema
44
+ 2. Handle joins appropriately if multiple tables are needed
45
+ 3. Use appropriate SQL functions based on the question context
46
+
47
+ SQL Query:"""
48
+
49
+ # Generate SQL query
50
+ inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
51
+ outputs = model.generate(
52
+ **inputs,
53
+ max_length=500,
54
+ num_return_sequences=1,
55
+ temperature=0.7,
56
+ top_p=0.95,
57
+ do_sample=True
58
+ )
59
+ sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
60
+
61
+ # Extract only the SQL part
62
+ sql_query = sql_query.split("SQL Query:")[-1].strip()
63
+ return sql_query
64
+
65
+ def execute_query(query):
66
+ conn = sqlite3.connect('data.db')
67
+ try:
68
+ result = pd.read_sql_query(query, conn)
69
+ return result, None
70
+ except Exception as e:
71
+ return None, str(e)
72
+ finally:
73
+ conn.close()
74
+
75
+ # Streamlit UI
76
+ st.title("Intelligent Text to SQL Query Assistant")
77
+ st.write("Ask questions about your data in natural language!")
78
+
79
+ # Display schema information in expandable section
80
+ with st.expander("View Database Schema"):
81
+ st.code(create_schema_prompt(), language="text")
82
+
83
+ # User input
84
+ user_question = st.text_area("Enter your question:", height=100)
85
+
86
+ if st.button("Generate and Execute Query"):
87
+ if user_question:
88
+ with st.spinner("Generating SQL query..."):
89
+ # Generate SQL query
90
+ sql_query = generate_sql_query(user_question)
91
+
92
+ # Display the generated query
93
+ st.subheader("Generated SQL Query:")
94
+ st.code(sql_query, language="sql")
95
+
96
+ # Execute the query
97
+ with st.spinner("Executing query..."):
98
+ results, error = execute_query(sql_query)
99
+
100
+ if error:
101
+ st.error(f"Error executing query: {error}")
102
+ else:
103
+ st.subheader("Query Results:")
104
+ st.dataframe(results)
data.db ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:426aac0b39b5ca888406c7471807463f82bc25b68e8568a6af16456cae01abf0
3
+ size 438272
database.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import pandas as pd
3
+ import os
4
+ from huggingface_hub import hf_hub_download
5
+ import io
6
+ import requests
7
+ from io import StringIO
8
+
9
+ def download_dataset(url):
10
+ # Convert URL to raw content URL
11
+ raw_url = url.replace('blob/', '')
12
+ raw_url = raw_url.replace('https://huggingface.co/', 'https://huggingface.co/')
13
+ raw_url = raw_url.replace('/tree/main', '/resolve/main')
14
+
15
+ headers = {
16
+ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
17
+ }
18
+
19
+ # Download the data
20
+ response = requests.get(raw_url, headers=headers)
21
+ response.raise_for_status() # Raise an exception for bad status codes
22
+
23
+ # Read CSV data
24
+ return pd.read_csv(StringIO(response.text))
25
+
26
+ def init_database():
27
+ # Create database connection
28
+ conn = sqlite3.connect('data.db')
29
+
30
+ try:
31
+ # Download files from Hugging Face
32
+ bonus_data_path = hf_hub_download(
33
+ repo_id="AIforAll16011991/bonus_data",
34
+ filename="Bonus_Data.csv",
35
+ repo_type="dataset"
36
+ )
37
+
38
+ player_kpi_path = hf_hub_download(
39
+ repo_id="AIforAll16011991/bonus_data",
40
+ filename="Player_KPIs.csv",
41
+ repo_type="dataset"
42
+ )
43
+
44
+ # Read CSV files
45
+ bonus_data = pd.read_csv(bonus_data_path)
46
+ player_kpi = pd.read_csv(player_kpi_path)
47
+
48
+ # Write to SQLite database
49
+ bonus_data.to_sql('bonus_data', conn, if_exists='replace', index=False)
50
+ player_kpi.to_sql('player_kpi', conn, if_exists='replace', index=False)
51
+
52
+ print("Database initialized successfully with data from Hugging Face!")
53
+
54
+ except Exception as e:
55
+ print(f"Error initializing database: {str(e)}")
56
+ raise
57
+ finally:
58
+ conn.close()
59
+
60
+ def get_schema_info():
61
+ conn = sqlite3.connect('data.db')
62
+ cursor = conn.cursor()
63
+
64
+ schema_info = {}
65
+
66
+ # Get table information
67
+ tables = cursor.execute("SELECT name FROM sqlite_master WHERE type='table';").fetchall()
68
+
69
+ for table in tables:
70
+ table_name = table[0]
71
+ # Get column information
72
+ columns = cursor.execute(f"PRAGMA table_info({table_name});").fetchall()
73
+
74
+ # Get sample data for each column
75
+ sample_data = pd.read_sql_query(f"SELECT * FROM {table_name} LIMIT 5", conn)
76
+
77
+ # Store column information and data types
78
+ schema_info[table_name] = {
79
+ 'columns': [col[1] for col in columns],
80
+ 'types': [col[2] for col in columns],
81
+ 'sample_values': {col: sample_data[col].tolist() for col in sample_data.columns}
82
+ }
83
+
84
+ conn.close()
85
+ return schema_info
86
+
87
+ if __name__ == "__main__":
88
+ init_database()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ streamlit==1.31.1
2
+ transformers==4.37.2
3
+ torch==2.2.0
4
+ pandas==2.2.0
5
+ sqlite-utils==3.35.2
6
+ requests==2.31.0
7
+ huggingface_hub==0.21.4