---
library_name: peft
base_model: GSAI-ML/LLaDA-8B-Instruct
tags:
- text-to-sql
- diffusion
- llada
- lora
- qlora
- generated_from_trainer
datasets:
- gretelai/synthetic_text_to_sql
license: apache-2.0
language:
- en
---
# LLaDA-8B Text-to-SQL (Diffusion-based)
## Model Summary
This model is a **Text-to-SQL** LoRA adapter fine-tuned on top of the `GSAI-ML/LLaDA-8B-Instruct` base model. Unlike traditional autoregressive (AR) models, which generate tokens left to right, it uses **masked iterative generation (diffusion)**.
It treats text generation as a denoising process: the response starts as a fully masked sequence, and tokens are unmasked over several steps, with the most confident predictions committed first. Because the underlying Transformer attends in both directions, each prediction can use context from both sides of a position during generation.
- **Task:** Text-to-SQL (Converting natural language questions + schema into SQL queries).
- **Method:** LLaDA (Large Language Diffusion with mAsking) with Block Diffusion sampling.
- **Fine-Tuning:** QLoRA (4-bit Quantization + LoRA).
## Model Details
- **Developed by:** [Tahamajs/Organization]
- **Base Model:** [GSAI-ML/LLaDA-8B-Instruct](https://huggingface.co/GSAI-ML/LLaDA-8B-Instruct)
- **Dataset:** [gretelai/synthetic_text_to_sql](https://huggingface.co/datasets/gretelai/synthetic_text_to_sql) (Subset of 20k samples)
- **Language:** English (Natural Language) -> SQL
- **Generation Strategy:** Semi-Autoregressive / Block Diffusion
## How to Use (Inference Code)
**Note:** This model does *not* work with the standard `model.generate()` function because it requires a custom diffusion sampling loop. Use the code below to generate SQL queries.
### 1. Setup & Loading
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"

# 1. Load base model (4-bit NF4 quantization)
base_model_id = "GSAI-ML/LLaDA-8B-Instruct"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    use_cache=False,
)
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)

# 2. Load LoRA adapter (this repo)
adapter_model_id = "YOUR_USERNAME/llada-text-to-sql-lora"  # Replace with your repo name
model = PeftModel.from_pretrained(model, adapter_model_id)
model.eval()
```
### 2. Define Block Diffusion Generation
```python
@torch.no_grad()
def generate_block_diffusion(model, tokenizer, prompt_text, steps=32, gen_len=64):
    """
    Generates text with LLaDA-style iterative unmasking.

    The whole response is denoised as a single block: each step runs a full
    forward pass and commits the most confident predictions at positions that
    are still masked. Committed tokens are never re-masked.
    """
    # Mask token id; 126336 is the mask id hard-coded in the official LLaDA
    # sampling code, used here as a fallback if the tokenizer does not set one.
    mask_id = tokenizer.mask_token_id if tokenizer.mask_token_id is not None else 126336

    # Tokenize prompt
    prompt_ids = tokenizer.encode(prompt_text, return_tensors="pt").to(model.device)
    prompt_len = prompt_ids.shape[1]

    # Initialize the response with [MASK] tokens
    mask_ids = torch.full((1, gen_len), mask_id, dtype=prompt_ids.dtype, device=model.device)
    input_ids = torch.cat([prompt_ids, mask_ids], dim=1)

    for step in range(steps):
        is_masked = input_ids[0] == mask_id
        num_masked = int(is_masked.sum())
        if num_masked == 0:
            break
        # Spread the remaining masks over the remaining steps (ceiling division),
        # so every position is filled by the last step even if gen_len % steps != 0
        k = -(-num_masked // (steps - step))

        # Forward pass over the prompt and the partially unmasked response
        logits = model(input_ids).logits
        confidences, predicted_ids = torch.softmax(logits.float(), dim=-1).max(dim=-1)

        # Only still-masked positions compete; prompt and committed tokens are excluded
        scores = confidences[0].masked_fill(~is_masked, -1.0)
        top_idx = torch.topk(scores, k).indices

        # Lock in the k most confident predictions
        input_ids[0, top_idx] = predicted_ids[0, top_idx]

    # Decode only the generated part
    return tokenizer.decode(input_ids[0, prompt_len:], skip_special_tokens=True)
```
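The sampler above denoises the whole response as one block. In the semi-autoregressive variant described for LLaDA, the response is instead split into fixed-length blocks that are completed left to right, each with its own share of the steps. The hypothetical helper `commit_schedule` below only computes that schedule (which response positions each block covers and how many masked tokens to commit at each step); it does not call the model. It assumes `gen_len` is a multiple of `block_len` and `steps` is a multiple of the number of blocks.

```python
def commit_schedule(gen_len, block_len, steps):
    """Hypothetical helper: per-block (start, end, tokens-to-commit-per-step) plan.

    The response is split into gen_len // block_len blocks denoised left to right.
    Each block gets steps // num_blocks steps, and its block_len masks are spread
    as evenly as possible over those steps (earlier steps take the remainder).
    """
    if gen_len % block_len != 0:
        raise ValueError("gen_len must be a multiple of block_len")
    num_blocks = gen_len // block_len
    if steps % num_blocks != 0:
        raise ValueError("steps must be a multiple of the number of blocks")
    steps_per_block = steps // num_blocks
    base, extra = divmod(block_len, steps_per_block)
    per_step = [base + (1 if i < extra else 0) for i in range(steps_per_block)]
    # Block b covers response positions [b * block_len, (b + 1) * block_len)
    return [(b * block_len, (b + 1) * block_len, list(per_step)) for b in range(num_blocks)]
```

With `gen_len=64` and `steps=32`, a single block (`block_len=64`) commits two tokens per step; `block_len=32` gives two blocks of 16 steps each, still two tokens per step, but the second block only starts once the first is fully unmasked.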
### 3. Run Inference
```python
schema = "CREATE TABLE users (id INTEGER, name TEXT, age INTEGER);"
question = "Show me the names of users older than 25."
prompt = f"""
<|im_start|>system
You are a Text-to-SQL assistant. Output ONLY the SQL query. Do not add explanations.<|im_end|>
<|im_start|>user
Schema:
{schema}
Question:
{question}<|im_end|>
<|im_start|>assistant
"""
output = generate_block_diffusion(model, tokenizer, prompt, steps=32, gen_len=64)
print("Generated SQL:", output)
```
## Training Details
### Training Configuration
* **Epochs:** 5
* **Batch Size:** 2 (Effective Batch Size = 8 via Gradient Accumulation)
* **Optimizer:** AdamW (lr=2e-4)
* **Scheduler:** Linear with Warmup (50 steps)
* **Context Length:** 384 tokens
* **Precision:** fp16 mixed precision
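Assuming the Hugging Face `Trainer` was used (the `generated_from_trainer` tag suggests so) on a single GPU, the configuration above corresponds roughly to the `TrainingArguments` below. `output_dir` is hypothetical, and the effective batch size of 8 is read as 2 per device times 4 accumulation steps. The 384-token context length is not a `TrainingArguments` field; it would be enforced at tokenization time (e.g. `truncation=True, max_length=384`). Note that `fp16=True` requires a CUDA device.

```python
from transformers import TrainingArguments

# Sketch of the listed hyperparameters; assumes a single GPU.
training_args = TrainingArguments(
    output_dir="llada-text-to-sql-lora",  # hypothetical output path
    num_train_epochs=5,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,        # 2 x 4 = effective batch size 8
    learning_rate=2e-4,
    optim="adamw_torch",
    lr_scheduler_type="linear",
    warmup_steps=50,
    fp16=True,                            # mixed precision (needs CUDA)
)
```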
### Noise Schedule
Training used a **forward masking** process: for each example, a time step $t \sim \mathcal{U}(0, 1)$ was sampled and each token in the answer was independently replaced with `[MASK]` with probability $t$. The loss was computed only on masked tokens and reweighted by $1/t$.
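A minimal pure-Python sketch of this objective, assuming the per-position log-probabilities of the original tokens are already available. `MASK_ID`, `sample_t`, `forward_mask`, and `reweighted_loss` are hypothetical names. Two further assumptions are not stated in this card: $t$ is drawn from $[\varepsilon, 1]$ so the $1/t$ weight stays finite, and the loss is averaged over the answer length $L$, i.e. $\frac{1}{t\,L}\sum_{i \in M} -\log p_\theta(x_i)$ for the masked set $M$. Treat it as an illustration, not the exact training code.

```python
import math
import random

MASK_ID = -1  # hypothetical [MASK] id for this sketch

def sample_t(rng, eps=1e-3):
    """Draw the masking ratio t from [eps, 1] (assumption: eps keeps 1/t finite)."""
    return rng.uniform(eps, 1.0)

def forward_mask(tokens, answer_start, t, rng):
    """Mask each answer token independently with probability t; prompt tokens are never masked."""
    noisy, masked = list(tokens), []
    for i in range(answer_start, len(tokens)):
        if rng.random() < t:
            noisy[i] = MASK_ID
            masked.append(i)
    return noisy, masked

def reweighted_loss(log_probs, masked, t, answer_len):
    """NLL on masked positions only, scaled by 1/t and averaged over the answer length."""
    nll = -sum(log_probs[i] for i in masked)
    return nll / (t * answer_len)

# Toy usage: positions 0-3 are the prompt, 4-9 the answer
rng = random.Random(0)
t = sample_t(rng)
noisy, masked = forward_mask(list(range(10)), answer_start=4, t=t, rng=rng)
```

Masking only answer tokens keeps the prompt (schema and question) fully visible, so the model always conditions on it.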
### LoRA Configuration
* **Rank (r):** 16
* **Alpha:** 32
* **Target Modules:** `q_proj`, `v_proj`
* **Dropout:** 0.05
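These LoRA settings map onto `peft.LoraConfig` as sketched below. Fields the card does not state (`bias`, `task_type`, and so on) are left at their PEFT defaults, so this is an approximation of the actual training setup. With standard LoRA scaling, the update is multiplied by `lora_alpha / r` = 32 / 16 = 2.

```python
from peft import LoraConfig

# Values from the LoRA configuration list; all other fields keep PEFT defaults.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,                        # scaling = lora_alpha / r = 2
    target_modules=["q_proj", "v_proj"],  # attention query/value projections only
    lora_dropout=0.05,
)
```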
## Evaluation Results
Evaluated with Block Diffusion sampling on 200 samples from the `gretelai/synthetic_text_to_sql` test split.
| Metric | Score |
| --- | --- |
| **Exact Match (EM)** | ~30% |
| **Normalized EM** | ~35-40%* |
\* Scores may vary depending on post-processing strictness and SQL normalization logic.
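To make the difference between the two metrics concrete, here is a hypothetical pair of scorers. Strict EM compares trimmed strings; the normalized variant lowercases, collapses whitespace, and drops trailing semicolons before comparing. This normalization is an assumption, not the logic behind the reported numbers, and lowercasing also folds string literals (`'Alice'` vs `'alice'`), which can overstate accuracy.

```python
import re

def exact_match(pred, gold):
    """Strict EM: identical after trimming surrounding whitespace."""
    return pred.strip() == gold.strip()

def normalize_sql(sql):
    """Hypothetical normalization: lowercase, collapse whitespace, drop trailing semicolons."""
    sql = re.sub(r"\s+", " ", sql.strip().lower())
    return sql.rstrip("; ").strip()

def normalized_em(pred, gold):
    return normalize_sql(pred) == normalize_sql(gold)

def score(preds, golds, metric):
    """Fraction of (prediction, reference) pairs the metric accepts."""
    return sum(metric(p, g) for p, g in zip(preds, golds)) / len(golds)
```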
## Bias, Risks, and Limitations
* **"Chatty" Output:** The model sometimes fails to produce an EOS token immediately after the semicolon, occasionally repeating the query or adding conversational filler. Post-processing (regex extraction of `SELECT ... ;`) is recommended.
* **Hallucination:** On complex queries, especially with long or complicated schemas, the model may reference columns that do not exist in the provided schema.
* **Inference Speed:** Each of the `steps` iterations is a full forward pass over the prompt and the whole response, so Block Diffusion sampling is typically slower than an autoregressive model of the same size that reuses a KV cache.
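The regex extraction recommended for chatty output can be as small as the hypothetical helper below, which keeps the first `SELECT ... ;` span and drops repeats or filler. It is a sketch: it ignores non-`SELECT` statements and would cut a query short if a string literal contains a semicolon.

```python
import re

# First SELECT ... ; span, case-insensitive, allowed to cross newlines (non-greedy)
_SQL_RE = re.compile(r"\bSELECT\b.*?;", re.IGNORECASE | re.DOTALL)

def extract_sql(text):
    """Return the first 'SELECT ... ;' statement, or the stripped text if none is found."""
    match = _SQL_RE.search(text)
    return match.group(0).strip() if match else text.strip()
```

For example, `extract_sql(output)` applied to the generation above returns only the first complete query.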
## Citation
If you use this model or the LLaDA technique, please cite the original paper:
```bibtex
@article{nie2024llada,
  title={Large Language Diffusion Models},
author={Nie, Shen and others},
journal={arXiv preprint arXiv:2402.XXXXX},
year={2024}
}
```