# File size: 854 Bytes
# b5c5ad5

from peft import PeftModel, PeftConfig
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the model. The PEFT adapter's config records which base checkpoint it
# was trained on, so that base model is loaded first and then wrapped with the
# adapter weights from the same repository.
ADAPTER_REPO = "shubh7/T5-Small-FineTuned-TexttoSql"

config = PeftConfig.from_pretrained(ADAPTER_REPO)
model = PeftModel.from_pretrained(
    AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path),
    ADAPTER_REPO,
)
# The tokenizer is taken from the adapter repository, not the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO)

# Sample inference
def generate_sql(question, max_length=128, **generate_kwargs):
    """Translate a natural-language question into SQL with the loaded model.

    Uses the module-level ``model`` and ``tokenizer``.

    Args:
        question: Natural-language text to translate.
        max_length: ``max_length`` passed to ``model.generate``.
        **generate_kwargs: Extra keyword arguments forwarded unchanged to
            ``model.generate`` (for example ``num_beams``).

    Returns:
        The decoded text of the first generated sequence (``outputs[0]``),
        with special tokens stripped.
    """
    inputs = tokenizer(question, return_tensors="pt", padding=True)
    # The tokenizer builds its tensors on the CPU. Move them to wherever the
    # model's weights live so generation still works after e.g.
    # ``model.to("cuda")``. This is a no-op while the model stays on the CPU.
    inputs = inputs.to(model.device)
    outputs = model.generate(**inputs, max_length=max_length, **generate_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Example: translate a single question and show it next to the generated SQL.
question = "Find all customers who placed orders in the last month"
sql = generate_sql(question)
print(f"Question: {question}", f"SQL: {sql}", sep="\n")