# adamabuhamdan's picture
# Update app.py
# 3a92db0 verified
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# Hub identifiers: the base chat model and the LoRA adapter applied on top.
BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
LORA_ADAPTER_ID = "adamabuhamdan/tinyllama-sql-lora"

# The tokenizer is loaded from the base model id.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

# Base weights are requested as float32 with everything mapped to the CPU.
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, torch_dtype=torch.float32, device_map="cpu")

# Wrap the base model with the LoRA adapter and switch to evaluation mode;
# eval() returns the module itself, so `model` is the adapter-wrapped model.
model = PeftModel.from_pretrained(base_model, LORA_ADAPTER_ID).eval()
def generate_sql(schema, question):
    """Generate a SQL query answering ``question`` for the given ``schema``.

    Builds a system + user chat prompt, renders it with the tokenizer's chat
    template, runs the module-level ``model`` with sampling disabled for at
    most 150 new tokens, and decodes only the tokens produced after the
    prompt.

    Args:
        schema: Table definition text (for example a ``CREATE TABLE``
            statement).
        question: Natural-language question about that schema.

    Returns:
        The decoded completion with special tokens removed and surrounding
        whitespace stripped. When either argument is ``None``, empty, or
        whitespace-only, a one-line SQL comment asking for both inputs is
        returned instead and the model is not invoked.
    """
    # Skip the (slow, CPU-bound) generation when there is nothing meaningful
    # to prompt with. The inputs themselves are left untouched so that the
    # prompt for valid inputs is exactly what the caller typed.
    if not (schema and schema.strip()) or not (question and question.strip()):
        return "-- Please provide both a table schema and a question."

    system_prompt = "You are a SQL assistant. Given a table schema and a question, reply with ONLY the SQL query, nothing else."
    user_prompt = f"Schema:\n{schema}\n\nQuestion: {question}"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    # Render the conversation to plain text, ending with the generation
    # prompt so the model continues as the assistant.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=150,
            do_sample=False,
        )

    # The output sequence starts with the prompt tokens; drop them so only
    # the newly generated text is decoded.
    input_length = inputs.input_ids.shape[1]
    prediction = tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True).strip()
    return prediction
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation had been stripped; confirm that the examples and the click
# wiring belong at Blocks level (after the row) rather than inside a column.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page title followed by a usage hint. The hint is in Arabic; in English:
    # "Enter the table structure and your question in natural language to get
    # instant SQL code."
    gr.Markdown("# 🤖 SQL Assistant (TinyLlama + LoRA)")
    gr.Markdown("قم بإدخال هيكل الجدول وسؤالك باللغة الطبيعية لتحصل على كود SQL فوري.")

    with gr.Row():
        # First column: the two text inputs and the submit button.
        with gr.Column():
            schema_input = gr.Textbox(
                lines=5,
                label="Database Schema",
                placeholder="CREATE TABLE users (id INT, name TEXT...);",
            )
            question_input = gr.Textbox(
                lines=2,
                label="Your Question",
                placeholder="List all users older than 25.",
            )
            submit_btn = gr.Button("Generate SQL", variant="primary")

        # Second column: the generated query, shown as SQL code.
        with gr.Column():
            sql_output = gr.Code(language="sql", label="Generated SQL Query")

    # Both the examples and the button feed the same pair of inputs.
    form_inputs = [schema_input, question_input]

    gr.Examples(
        inputs=form_inputs,
        examples=[
            ["CREATE TABLE employees (id INT, name TEXT, salary INT);", "Show names of employees earning more than 5000."],
            ["CREATE TABLE movies (title TEXT, year INT, rating FLOAT);", "Find the highest rated movie from 2022."],
        ],
    )

    # Clicking the button runs generate_sql on the two inputs and writes the
    # result into the code panel.
    submit_btn.click(fn=generate_sql, inputs=form_inputs, outputs=sql_output)

demo.launch()