Spaces:
Runtime error
Runtime error
| import torch | |
| from transformers import T5Tokenizer, T5ForConditionalGeneration | |
| import os | |
| import pandas as pd | |
# Absolute path of the project root, i.e. the parent of the directory that
# holds this script (computed lexically; symlinks are not resolved).
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def load_model(model_name="cssupport/t5-small-awesome-text-to-sql", tokenizer_name=None):
    """Load a T5 text-to-SQL model and its tokenizer, ready for inference.

    Args:
        model_name: Hub id or local directory of the
            ``T5ForConditionalGeneration`` checkpoint. Defaults to the
            public ``cssupport/t5-small-awesome-text-to-sql`` model; pass the
            directory written by ``save_pretrained`` to reuse a local copy.
        tokenizer_name: Hub id or local directory for the tokenizer. Defaults
            to ``model_name``; set it separately when the checkpoint location
            does not also provide tokenizer files.

    Returns:
        ``(model, tokenizer, device)``. The model has been moved to
        ``device`` (CUDA when available, otherwise CPU) and switched to eval
        mode.
    """
    print("📦 Loading pre-trained text-to-SQL model...")
    if tokenizer_name is None:
        tokenizer_name = model_name
    tokenizer = T5Tokenizer.from_pretrained(tokenizer_name)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    model = model.to(device)
    # Inference only: eval() disables dropout so generation is deterministic
    # for a given decoding strategy.
    model.eval()
    return model, tokenizer, device
def generate_sql(question, schema, model, tokenizer, device, max_length=512):
    """Translate a natural-language question into a SQL query.

    Args:
        question: The natural-language question to answer.
        schema: Table definitions (e.g. CREATE TABLE statements) passed to the
            model as context.
        model: Seq2seq model exposing ``generate`` (e.g. from ``load_model``).
        tokenizer: Tokenizer matching ``model``.
        device: Torch device the model lives on; the encoded prompt is moved
            there before generation.
        max_length: Maximum length, in tokens, of the generated sequence;
            forwarded to ``model.generate``. Defaults to 512.

    Returns:
        The decoded SQL text, with special tokens stripped.
    """
    # Prompt layout used by this model: schema block first, then the question.
    input_prompt = f"tables:\n{schema}\nquery for: {question}"
    # A single prompt needs no padding. truncation=True cuts from the right at
    # the tokenizer's maximum input length.
    # NOTE(review): because the question comes last, a very large schema can
    # truncate the question away entirely — consider trimming the schema.
    inputs = tokenizer(input_prompt, truncation=True, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_length=max_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def _sql_identifier(name):
    """Return *name* as a SQL identifier, double-quoting it when required.

    Plain ASCII identifiers (letters, digits and underscores, not starting
    with a digit) are returned unchanged. Anything else (spaces, hyphens,
    leading digits, non-ASCII characters) is wrapped in double quotes, with
    embedded double quotes doubled as the SQL standard requires. SQL reserved
    words are not detected.
    """
    name = str(name)
    if name.isascii() and name.isidentifier():
        return name
    return '"' + name.replace('"', '""') + '"'


def get_schema_from_csv(csv_path):
    """Generate a CREATE TABLE statement describing a CSV file.

    The table is named after the file name without its extension. Column
    types are inferred from the dtypes pandas assigns when reading the whole
    file:
        booleans -> BOOLEAN
        any integer width -> INT
        any float width -> DECIMAL(10,2)
        everything else -> VARCHAR(255)
    Names that are not plain identifiers are double-quoted so the DDL stays
    valid.

    Args:
        csv_path: Path to a CSV file with a header row.

    Returns:
        The CREATE TABLE statement as a single string.
    """
    df = pd.read_csv(csv_path)
    ptypes = pd.api.types
    columns = []
    for col in df.columns:
        dtype = df[col].dtype
        # Use dtype predicates rather than exact 'int64'/'float64' strings so
        # other widths (int32, float32, ...) are classified too; bool is
        # checked explicitly because it is neither an integer nor a float dtype.
        if ptypes.is_bool_dtype(dtype):
            col_type = "BOOLEAN"
        elif ptypes.is_integer_dtype(dtype):
            col_type = "INT"
        elif ptypes.is_float_dtype(dtype):
            col_type = "DECIMAL(10,2)"
        else:
            col_type = "VARCHAR(255)"
        columns.append(f"{_sql_identifier(col)} {col_type}")
    table_name = _sql_identifier(os.path.splitext(os.path.basename(csv_path))[0])
    return f"CREATE TABLE {table_name} (\n " + ",\n ".join(columns) + "\n);"
if __name__ == "__main__":
    # Load the pre-trained model.
    model, tokenizer, device = load_model()

    # Save a local copy of the model and tokenizer for future use.
    output_dir = os.path.join(PROJECT_ROOT, "model_sqlgen_t5")
    print(f"💾 Saving model to {output_dir}")
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"✅ Model successfully saved to {output_dir}")

    # Example usage: build a schema from the sample CSV and ask one question.
    csv_path = os.path.join(PROJECT_ROOT, "data", "retail_dataset.csv")
    # isfile (not exists) so a directory with this name is not handed to
    # read_csv; say so explicitly instead of silently skipping the example.
    if not os.path.isfile(csv_path):
        print(f"\nExample skipped: CSV file not found at {csv_path}")
    else:
        schema = get_schema_from_csv(csv_path)
        print("\nGenerated schema from CSV:")
        print(schema)
        question = "What is the total sales amount for each product category?"
        sql_query = generate_sql(question, schema, model, tokenizer, device)
        print("\nExample usage:")
        print(f"Question: {question}")
        print(f"Generated SQL: {sql_query}")