"""Gradio demo: turn a table schema plus a natural-language question into SQL.

Loads TinyLlama-1.1B-Chat on CPU, applies a LoRA adapter with PEFT, and serves
a small Blocks UI whose button calls :func:`generate_sql`.
"""

import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
LORA_ADAPTER_ID = "adamabuhamdan/tinyllama-sql-lora"

# Default cap on the number of tokens generated per request.
MAX_NEW_TOKENS = 150

SYSTEM_PROMPT = (
    "You are a SQL assistant. Given a table schema and a question, "
    "reply with ONLY the SQL query, nothing else."
)

# Load once at import time so every request reuses the same weights.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    torch_dtype=torch.float32,
    device_map="cpu",
)
# Wrap the base model with the LoRA adapter weights.
model = PeftModel.from_pretrained(base_model, LORA_ADAPTER_ID)
model.eval()  # inference only: puts dropout-style layers into eval mode


def _build_prompt(schema, question):
    """Return the chat-template-formatted prompt text for *schema* and *question*."""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Schema:\n{schema}\n\nQuestion: {question}"},
    ]
    # add_generation_prompt=True appends the assistant-turn header so the
    # model's continuation is the answer rather than another user turn.
    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )


def generate_sql(schema, question, max_new_tokens=MAX_NEW_TOKENS):
    """Generate a SQL query that answers *question* against *schema*.

    Args:
        schema: Table definition text, e.g. one or more ``CREATE TABLE`` statements.
        question: Natural-language question to translate into SQL.
        max_new_tokens: Upper bound on the number of tokens the model may generate.

    Returns:
        The decoded, whitespace-stripped model continuation (the prompt is
        excluded), or an SQL comment asking for input when either the schema
        or the question is blank.
    """
    schema = (schema or "").strip()
    question = (question or "").strip()
    if not schema or not question:
        # Avoid a slow CPU generation call when there is nothing to answer.
        return "-- Please provide both a table schema and a question."

    inputs = tokenizer(_build_prompt(schema, question), return_tensors="pt")
    with torch.no_grad():
        # do_sample=False: no sampling, so the same input gives the same output.
        outputs = model.generate(
            **inputs, max_new_tokens=max_new_tokens, do_sample=False
        )

    # generate() returns prompt + continuation; keep only the new tokens.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 SQL Assistant (TinyLlama + LoRA)")
    # Arabic: "Enter the table structure and your question in natural
    # language to get SQL code instantly."
    gr.Markdown("قم بإدخال هيكل الجدول وسؤالك باللغة الطبيعية لتحصل على كود SQL فوري.")

    with gr.Row():
        with gr.Column():
            schema_input = gr.Textbox(
                label="Database Schema",
                placeholder="CREATE TABLE users (id INT, name TEXT...);",
                lines=5,
            )
            question_input = gr.Textbox(
                label="Your Question",
                placeholder="List all users older than 25.",
                lines=2,
            )
            submit_btn = gr.Button("Generate SQL", variant="primary")
        with gr.Column():
            sql_output = gr.Code(label="Generated SQL Query", language="sql")

    gr.Examples(
        examples=[
            [
                "CREATE TABLE employees (id INT, name TEXT, salary INT);",
                "Show names of employees earning more than 5000.",
            ],
            [
                "CREATE TABLE movies (title TEXT, year INT, rating FLOAT);",
                "Find the highest rated movie from 2022.",
            ],
        ],
        inputs=[schema_input, question_input],
    )

    submit_btn.click(
        fn=generate_sql,
        inputs=[schema_input, question_input],
        outputs=sql_output,
    )

demo.launch()