Spaces:
Runtime error
Runtime error
| import pandas as pd | |
| import os | |
| import argparse | |
| import shutil | |
| import tempfile | |
| from google.cloud import storage | |
| from transformers import T5Tokenizer, T5ForConditionalGeneration, Trainer, TrainingArguments | |
| from datasets import Dataset | |
| import torch | |
# CLI arguments
parser = argparse.ArgumentParser(
    description="Fine-tune t5-small to translate questions into SQL and upload it to GCS."
)
parser.add_argument("--dataset_path", type=str, required=True,
                    help="CSV file containing 'question' and 'sql' columns.")
parser.add_argument("--output_dir", type=str, required=True,
                    help="GCS prefix (gs://bucket[/path]) the model is uploaded under.")
args = parser.parse_args()
# Fail fast: the upload step at the end of the script only understands gs://
# URIs, so reject anything else before spending time on training.
if not args.output_dir.startswith("gs://"):
    parser.error("--output_dir must be a GCS URI of the form gs://bucket[/path]")

print("📦 Loading dataset from:", args.dataset_path)
df = pd.read_csv(args.dataset_path)
df = df[["question", "sql"]].rename(columns={"question": "input_text", "sql": "target_text"})
# A missing question or query would become a NaN float after the prefix
# concatenation below and crash the tokenizer, so drop such rows. astype(str)
# guards against numeric-looking cells. reset_index keeps Dataset.from_pandas
# from adding an "__index_level_0__" column for the non-contiguous index.
df = df.dropna(subset=["input_text", "target_text"]).astype(str).reset_index(drop=True)
df["input_text"] = "translate question to SQL: " + df["input_text"]
dataset = Dataset.from_pandas(df)
# Load the pretrained T5 checkpoint. The tokenizer and the model are taken from
# the same checkpoint name so that their vocabularies agree.
model_name = "t5-small"
tokenizer, model = (
    T5Tokenizer.from_pretrained(model_name),
    T5ForConditionalGeneration.from_pretrained(model_name),
)
def preprocess(example, tok=None, max_length=128):
    """Tokenize one question/SQL pair for seq2seq fine-tuning.

    Args:
        example: mapping with an ``"input_text"`` (prefixed question) and a
            ``"target_text"`` (SQL) string, as passed in by ``dataset.map``.
        tok: tokenizer to use. Defaults to the module-level ``tokenizer``.
            Exposed mainly so the function can be exercised without loading T5.
        max_length: length both source and target are padded/truncated to.

    Returns:
        The tokenizer's encoding of the input text (``input_ids``,
        ``attention_mask``, ...) with an added ``labels`` list.
    """
    if tok is None:
        tok = tokenizer
    input_enc = tok(example["input_text"], truncation=True, padding="max_length", max_length=max_length)
    target_enc = tok(example["target_text"], truncation=True, padding="max_length", max_length=max_length)
    # Padding positions in the labels must be -100 so the cross-entropy loss
    # ignores them. Otherwise the model is also trained to predict the pad token
    # after every query. The target attention mask marks the real tokens.
    input_enc["labels"] = [
        token_id if keep else -100
        for token_id, keep in zip(target_enc["input_ids"], target_enc["attention_mask"])
    ]
    return input_enc
# Tokenize every example. The raw text columns are dropped explicitly because
# the model only consumes input_ids / attention_mask / labels.
tokenized_dataset = dataset.map(preprocess, remove_columns=dataset.column_names)

# Training arguments.
# `evaluation_strategy` was renamed to `eval_strategy` and the old keyword was
# later removed from TrainingArguments (passing it raises TypeError on recent
# transformers). No eval dataset is used and "no" is already the default, so
# the argument is omitted, which works on old and new versions alike.
training_args = TrainingArguments(
    output_dir="./results_t5_sqlgen",
    per_device_train_batch_size=4,
    num_train_epochs=10,
    logging_dir="./logs",
    logging_steps=5,
    save_strategy="epoch",
)

# Train model
trainer = Trainer(model=model, args=training_args, train_dataset=tokenized_dataset)
trainer.train()
# Resolve the GCS destination first so a bad URI fails before anything is written.
# GCS object names always use "/", so the path is built explicitly rather than
# with os.path.join (whose separator depends on the local OS).
gcs_model_path = args.output_dir.rstrip("/") + "/sqlgen"
if not gcs_model_path.startswith("gs://"):
    raise ValueError(f"--output_dir must be a gs:// URI, got {args.output_dir!r}")
# "gs://bucket/a/b/sqlgen" -> bucket "bucket", object prefix "a/b/sqlgen"
bucket_name, _, base_path = gcs_model_path[len("gs://"):].partition("/")

client = storage.Client()
bucket = client.bucket(bucket_name)  # loop-invariant: resolve once

# Save the model and tokenizer to a temporary local directory and upload every
# file. The context manager deletes the local copy afterwards; the previous
# mkdtemp() directory was never removed.
with tempfile.TemporaryDirectory() as local_dir:
    model.save_pretrained(local_dir)
    tokenizer.save_pretrained(local_dir)

    for fname in sorted(os.listdir(local_dir)):
        local_path = os.path.join(local_dir, fname)
        if not os.path.isfile(local_path):
            continue  # upload_from_filename only handles regular files
        gcs_blob_path = f"{base_path}/{fname}"
        print(f"⬆️ Uploading {fname} to gs://{bucket_name}/{gcs_blob_path}")
        blob = bucket.blob(gcs_blob_path)
        blob.upload_from_filename(local_path)

print(f"✅ Model successfully uploaded to gs://{bucket_name}/{base_path}")