Spaces:
Running
Running
| import torch | |
| import sqlite3 | |
| import pandas as pd | |
| import gradio as gr | |
| from langchain_community.llms import HuggingFacePipeline | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
# ============================================================
# Load the SQLCoder model
# ============================================================
model_id = "defog/sqlcoder-7b-2"

# Tokenizer and weights come from the same Hugging Face checkpoint;
# dtype and device placement are both left to transformers ("auto").
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype="auto")

# Greedy (non-sampled) generation, capped at 256 new tokens per call.
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    do_sample=False,
    max_new_tokens=256,
)

# LangChain wrapper around the pipeline; the query function below calls it via .invoke().
sqlcoder_llm = HuggingFacePipeline(pipeline=pipe)
# ============================================================
# 🧠 Define query function
# ============================================================
def _db_path(user_db):
    """Return the filesystem path of a Gradio file upload.

    Depending on the Gradio version and ``gr.File(type=...)``, the upload
    arrives either as a path string or as a tempfile-like object whose
    ``.name`` attribute holds the path; both forms are accepted.
    """
    import os

    if isinstance(user_db, (str, os.PathLike)):
        return os.fspath(user_db)
    return user_db.name


def _connect_read_only(path):
    """Open the SQLite database at *path* in read-only mode.

    The SQL executed later is produced by a language model from free-form
    user input, so the file is opened with ``mode=ro``: a generated
    DROP/DELETE/UPDATE fails instead of modifying the uploaded database.
    Raises ``sqlite3.OperationalError`` if the file cannot be opened.
    """
    from pathlib import Path

    # as_uri() percent-encodes characters such as '?' and '#' in the path.
    uri = Path(path).resolve().as_uri() + "?mode=ro"
    return sqlite3.connect(uri, uri=True)


def _read_schema(conn):
    """Return the CREATE statements of all user tables and views in *conn*.

    Statements are separated by blank lines. An empty string means the
    database defines no tables or views.
    """
    rows = conn.execute(
        "SELECT sql FROM sqlite_master "
        "WHERE type IN ('table', 'view') AND sql IS NOT NULL "
        "AND name NOT LIKE 'sqlite_%' "
        "ORDER BY name"
    ).fetchall()
    return "\n\n".join(ddl for (ddl,) in rows)


def _build_prompt(schema, question):
    """Return the text-to-SQL prompt for *question* against *schema*.

    The prompt ends with a ``SQL:`` marker; the answer extraction relies on
    it to find where the model's output starts when the prompt is echoed.
    """
    return (
        "You are an expert SQL generator.\n"
        "The database has the following schema:\n"
        f"{schema}\n"
        "Translate this question into a valid SQLite query for this schema.\n"
        "Return only SQL (no text).\n"
        f"Question: {question}\n"
        "SQL:\n"
    )


def _extract_sql(response):
    """Return the first complete SQL statement found in a model *response*.

    Returns an empty string when the response contains no SQL text.
    """
    # .invoke() on a LangChain LLM is expected to return a str; the dict and
    # list shapes of raw pipeline output are handled defensively.
    if isinstance(response, dict) and "text" in response:
        response = response["text"]
    elif isinstance(response, list):
        response = response[0]["generated_text"]

    # The pipeline output may echo the prompt, so keep only the text after
    # the last "SQL:" marker.
    text = response.strip().split("SQL:")[-1]
    # Drop Markdown code-fence lines (```sql / ```) if the model emitted any.
    text = "\n".join(
        line for line in text.splitlines() if not line.strip().startswith("```")
    ).strip()

    # Cut at the first semicolon that terminates a complete statement, so
    # multi-line queries stay whole while trailing chatter is discarded.
    # complete_statement() ignores semicolons inside string literals/comments.
    end = text.find(";")
    while end != -1:
        candidate = text[: end + 1]
        if sqlite3.complete_statement(candidate):
            return candidate
        end = text.find(";", end + 1)

    # No terminated statement: assume the query runs up to the first blank line.
    text = text.split("\n\n")[0].strip()
    return f"{text};" if text else ""


def ask_question(user_db, question, *, llm=None):
    """Translate a natural-language question to SQL and run it on an uploaded SQLite DB.

    The prompt contains the schema read from the uploaded file itself, so any
    SQLite database can be queried. The database is opened read-only, so
    generated statements cannot modify it.

    Args:
        user_db: The Gradio upload (a path string or an object with a
            ``.name`` path attribute); falsy when nothing was uploaded.
        question: Natural-language question to translate.
        llm: Object with an ``invoke(prompt)`` method returning the model
            output. Defaults to the module-level ``sqlcoder_llm``.

    Returns:
        ``(sql_query, DataFrame)`` on success, otherwise
        ``(error_message, None)``.
    """
    from contextlib import closing

    if not user_db:
        return "❌ Please upload a database file.", None
    if not question or not question.strip():
        return "❌ Please enter a question.", None
    if llm is None:
        llm = sqlcoder_llm

    try:
        conn = _connect_read_only(_db_path(user_db))
    except sqlite3.Error as e:
        return f"❌ Could not open database: {e}", None

    # closing() releases the connection on every path, including when the
    # model call raises.
    with closing(conn):
        try:
            schema = _read_schema(conn)
        except sqlite3.Error as e:  # e.g. the uploaded file is not a SQLite database
            return f"❌ Could not read database schema: {e}", None
        if not schema:
            return "❌ The database contains no tables.", None

        sql_query = _extract_sql(llm.invoke(_build_prompt(schema, question)))
        if not sql_query:
            return "❌ The model did not return any SQL.", None

        try:
            cursor = conn.execute(sql_query)
            rows = cursor.fetchall()
        except (sqlite3.Error, sqlite3.Warning) as e:
            return f"❌ Error executing query: {e}\n\nGenerated SQL:\n{sql_query}", None
        # description is None for statements that produce no result columns.
        columns = [desc[0] for desc in cursor.description or ()]
        return sql_query, pd.DataFrame(rows, columns=columns)
# ============================================================
# 🎨 Gradio UI
# ============================================================
# Two inputs (uploaded database file, question text) are passed positionally
# to ask_question; its (SQL text, DataFrame) result fills the two outputs.
demo = gr.Interface(
    fn=ask_question,
    inputs=[
        gr.File(label="Upload SQLite Database (.db)"),
        gr.Textbox(label="Ask your question"),
    ],
    outputs=[
        gr.Textbox(label="Generated SQL Query"),
        gr.Dataframe(label="Query Result"),
    ],
    # Title previously contained mojibake ("π§ ") for the 🧠 emoji.
    title="🧠 Text-to-SQL on Your Own Database",
    description="Upload your SQLite database and ask natural language questions.",
)

if __name__ == "__main__":
    demo.launch()