|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
from peft import PeftModel |
|
|
|
|
|
|
|
|
# Hugging Face Hub ID of the frozen base model that the adapter plugs into.
BASE_MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"

# Hub ID of the fine-tuned (QLoRA) adapter weights for SQL generation.
ADAPTER_MODEL_ID = "manuelaschrittwieser/Qwen2.5-1.5B-SQL-Assistant"

# Startup notice (German: "Loading models... the first start can take 1-2 minutes.")
print("Lade Modelle... das kann beim ersten Start 1-2 Minuten dauern.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the frozen base model on CPU in full float32 precision (no GPU assumed
# in the hosting environment, so no quantization / half precision here).
_load_kwargs = {
    "device_map": "cpu",
    "torch_dtype": torch.float32,
}
base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_ID, **_load_kwargs)

# Wrap the base model with the fine-tuned LoRA adapter weights.
model = PeftModel.from_pretrained(base_model, ADAPTER_MODEL_ID)

# Tokenizer comes from the base checkpoint (the adapter ID ships no tokenizer
# of its own here — presumably the vocabulary is unchanged; verify if in doubt).
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
|
|
|
|
|
|
|
|
def generate_sql(context, question):
    """Generate a SQL query answering *question* against the schema in *context*.

    Parameters
    ----------
    context : str
        Database schema, typically one or more CREATE TABLE statements.
    question : str
        Natural-language question about the data.

    Returns
    -------
    str
        The decoded model response (expected to be a SQL query).
    """
    messages = [
        {"role": "system", "content": "You are a SQL expert. Generate a SQL query based on the context."},
        {"role": "user", "content": f"{context}\nQuestion: {question}"}
    ]

    # Render the model-specific chat prompt and tokenize it.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt")

    # Greedy (deterministic) decoding.  FIX: the original also passed
    # temperature=0.0, which conflicts with do_sample=False — transformers
    # ignores it and emits a warning — so the argument is dropped.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=False
        )

    # FIX: decode only the newly generated tokens.  The original decoded the
    # full sequence and split on the literal "assistant", which breaks
    # whenever the schema, the question, or the SQL itself contains that word.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()

    return response
|
|
|
|
|
|
|
|
# Markdown blurb rendered above the demo (German UI text; string content —
# including the interior blank lines — is user-facing and kept verbatim).
description = """


# 🤖 Text-to-SQL Assistant


Gib ein Datenbankschema (CREATE TABLE) und eine Frage ein, um den passenden SQL-Code zu erhalten.


*Trainiert mit QLoRA auf Qwen 2.5 (1.5B).*


"""

# Clickable example rows shown under the interface; each entry is
# [schema, question] matching the two input boxes in order.
examples = [
    [
        "CREATE TABLE employees (name VARCHAR, department VARCHAR, salary INT)",
        "Show me all employees in the Sales department who earn more than 50000.",
    ],
    [
        "CREATE TABLE students (id INT, name VARCHAR, grade INT)",
        "Count how many students are in grade 10.",
    ],
]
|
|
|
|
|
# Input widgets: schema first, then the natural-language question —
# order must match the (context, question) signature of generate_sql.
_schema_box = gr.Textbox(lines=3, label="Database Context (Schema)", placeholder="CREATE TABLE...")
_question_box = gr.Textbox(lines=2, label="Question", placeholder="What is...")
_sql_output = gr.Code(language="sql", label="Generated SQL")

# Wire the generation function into a simple two-input demo UI.
iface = gr.Interface(
    fn=generate_sql,
    inputs=[_schema_box, _question_box],
    outputs=_sql_output,
    title="Technical Assistant Demo",
    description=description,
    examples=examples,
)

# Start the Gradio server (blocks until the app is stopped).
iface.launch()