Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import sqlite3 | |
| from datasets import Dataset | |
| from transformers import T5Tokenizer | |
| # ========================================================= | |
| # PROJECT ROOT (VERY IMPORTANT — fixes path issues) | |
| # ========================================================= | |
# Resolve the project root from this file's own location (two directory
# levels up) so that every path below is independent of the current
# working directory.
_THIS_FILE = os.path.abspath(__file__)
BASE_DIR = os.path.dirname(os.path.dirname(_THIS_FILE))

# Everything the script reads or writes lives under <root>/data.
_DATA_DIR = os.path.join(BASE_DIR, "data")
_TOKENIZED_DIR = os.path.join(_DATA_DIR, "tokenized")

# Inputs: example JSON files and the folder of per-database SQLite files.
TRAIN_JSON = os.path.join(_DATA_DIR, "train_spider.json")
DEV_JSON = os.path.join(_DATA_DIR, "dev.json")
DB_FOLDER = os.path.join(_DATA_DIR, "database")

# Outputs: one saved dataset directory per split.
SAVE_TRAIN = os.path.join(_TOKENIZED_DIR, "train")
SAVE_DEV = os.path.join(_TOKENIZED_DIR, "validation")

# Create the shared parent of both output directories up front.
os.makedirs(_TOKENIZED_DIR, exist_ok=True)

for _label, _value in (
    ("Project root:", BASE_DIR),
    ("Train file:", TRAIN_JSON),
    ("Database folder:", DB_FOLDER),
):
    print(_label, _value)
| # ========================================================= | |
| # TOKENIZER | |
| # ========================================================= | |
# Name of the pretrained checkpoint whose tokenizer is used for both the
# prompts and the target SQL in tokenize().
_TOKENIZER_CHECKPOINT = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(_TOKENIZER_CHECKPOINT)
| # ========================================================= | |
| # READ DATABASE SCHEMA | |
| # ========================================================= | |
def get_schema(db_path):
    """Return a plain-text description of the tables in a SQLite database.

    Every row of type ``table`` in ``sqlite_master`` becomes one line of the
    form ``table(col1, col2, ...)``; the lines are joined with newlines.

    Args:
        db_path: Filesystem path of the SQLite database file.

    Returns:
        The newline-joined schema text, or an empty string when the
        database contains no tables.

    Raises:
        sqlite3.Error: Propagated unchanged if any query fails; the
            connection is closed before the error leaves this function.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        table_rows = cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table';"
        ).fetchall()

        schema_text = []
        for (table,) in table_rows:
            # PRAGMA arguments cannot be bound as SQL parameters, so the
            # table name is embedded as a double-quoted identifier (with
            # embedded quotes doubled). This keeps names that contain
            # spaces or collide with SQL keywords from breaking the query.
            quoted = '"' + table.replace('"', '""') + '"'
            columns = cursor.execute(
                f"PRAGMA table_info({quoted});"
            ).fetchall()
            # Index 1 of each table_info row is the column name.
            col_names = [c[1] for c in columns]
            schema_text.append(f"{table}({', '.join(col_names)})")
    finally:
        # Always release the connection, even when a query raised.
        conn.close()

    return "\n".join(schema_text)
| # ========================================================= | |
| # BUILD TRAINING EXAMPLES | |
| # ========================================================= | |
def build_examples(spider_json):
    """Build a prompt/target dataset from a JSON file of examples.

    Each example must provide the keys ``question``, ``query`` and
    ``db_id``. The schema of the matching database file
    ``DB_FOLDER/<db_id>/<db_id>.sqlite`` is prepended to the question to
    form the model input; the SQL query is the target.

    Args:
        spider_json: Path of the JSON file holding a list of examples.

    Returns:
        A ``Dataset`` with two string columns: ``input`` (the schema-aware
        prompt) and ``output`` (the SQL query). Examples whose database
        file does not exist are left out.
    """
    print(f"\nBuilding dataset from: {spider_json}")

    # Context manager closes the file promptly; explicit UTF-8 avoids
    # depending on the platform's default encoding.
    with open(spider_json, encoding="utf-8") as f:
        data = json.load(f)

    # Many examples share one database, so read each schema only once.
    # A value of None records that the database file is missing.
    schema_cache = {}
    inputs = []
    outputs = []
    skipped = 0

    for ex in data:
        question = ex["question"]
        sql = ex["query"]
        db_id = ex["db_id"]

        if db_id not in schema_cache:
            db_path = os.path.join(DB_FOLDER, db_id, f"{db_id}.sqlite")
            schema_cache[db_id] = (
                get_schema(db_path) if os.path.exists(db_path) else None
            )

        schema = schema_cache[db_id]
        if schema is None:
            # Database missing: skip the example, but keep count so the
            # omission is reported instead of passing silently.
            skipped += 1
            continue

        # Schema-aware prompt: the schema precedes the question.
        input_text = f"""Database Schema:
{schema}
Translate English to SQL:
{question}
SQL:
"""
        inputs.append(input_text)
        outputs.append(sql)

    if skipped:
        print(f"Skipped {skipped} example(s) with no database file under {DB_FOLDER}")

    return Dataset.from_dict({"input": inputs, "output": outputs})
| # ========================================================= | |
| # TOKENIZE | |
| # ========================================================= | |
def tokenize(
    example,
    *,
    max_input_length=512,
    max_target_length=256,
    ignore_pad_token_for_loss=False,
):
    """Tokenize one ``input``/``output`` example for seq2seq training.

    Both texts are truncated and padded to a fixed length with the
    module-level ``tokenizer``. With the default arguments the result is
    identical to the previous hard-coded behaviour (512 / 256 tokens,
    padding ids kept in the labels).

    Args:
        example: Mapping with string fields ``input`` and ``output``
            (a single example, i.e. ``Dataset.map(..., batched=False)``).
        max_input_length: Fixed token length of the encoded prompt.
        max_target_length: Fixed token length of the encoded target.
        ignore_pad_token_for_loss: When True, padding positions in the
            labels are replaced by -100, the index ignored by PyTorch's
            cross-entropy loss. Off by default so existing consumers of
            the saved dataset see unchanged label ids.

    Returns:
        The tokenizer output for the prompt with an added ``labels`` entry
        holding the target token ids.
    """
    model_input = tokenizer(
        example["input"],
        max_length=max_input_length,
        padding="max_length",
        truncation=True,
    )
    label = tokenizer(
        example["output"],
        max_length=max_target_length,
        padding="max_length",
        truncation=True,
    )

    label_ids = label["input_ids"]
    if ignore_pad_token_for_loss:
        # Assumes a flat list of ids (one example, not a batch).
        pad_id = tokenizer.pad_token_id
        label_ids = [-100 if tok == pad_id else tok for tok in label_ids]

    model_input["labels"] = label_ids
    return model_input
| # ========================================================= | |
| # RUN PIPELINE | |
| # ========================================================= | |
def _prepare_split(split_name, json_path, save_path):
    """Build, tokenize and save one split; return (raw, tokenized)."""
    print(f"\nBuilding {split_name} dataset...")
    raw = build_examples(json_path)

    print(f"Tokenizing {split_name} dataset...")
    tokenized = raw.map(tokenize, batched=False)

    print(f"Saving {split_name} dataset...")
    tokenized.save_to_disk(save_path)
    return raw, tokenized


# Run both splits through the same three steps, training split first.
train_dataset, tokenized_train = _prepare_split("TRAIN", TRAIN_JSON, SAVE_TRAIN)
val_dataset, tokenized_val = _prepare_split("VALIDATION", DEV_JSON, SAVE_DEV)

# Final summary of where the saved datasets live.
print("\nDONE ✔ Dataset prepared successfully!")
for _caption, _location in (
    ("Train saved at:", SAVE_TRAIN),
    ("Validation saved at:", SAVE_DEV),
):
    print(_caption, _location)