Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
print("🚀 Loading model...")

# Merged fine-tuned checkpoint on the Hugging Face Hub (see the UI description:
# Llama 3.2 3B fine-tuned with LoRA).
model_name = "Abhisek987/llama-3.2-sql-merged"

# float16 is a good fit only on CUDA. On CPU-only hardware, half precision is slow
# and some torch builds lack CPU kernels for it. bfloat16 keeps the same memory
# footprint and is the better-supported reduced-precision dtype on CPU.
_dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=_dtype,
    low_cpu_mem_usage=True,
)
print("✅ Model loaded successfully!")
| def generate_sql(database, question): | |
| """Generate SQL query from natural language question""" | |
| prompt = f"""### Instruction: | |
| You are a SQL expert. Generate a SQL query to answer the given question for the specified database. | |
| ### Input: | |
| Database: {database} | |
| Question: {question} | |
| ### Response: | |
| """ | |
| # Tokenize and generate | |
| inputs = tokenizer(prompt, return_tensors="pt").to(model.device) | |
| with torch.no_grad(): | |
| outputs = model.generate( | |
| **inputs, | |
| max_new_tokens=256, | |
| temperature=0.1, | |
| do_sample=True, | |
| pad_token_id=tokenizer.eos_token_id, | |
| eos_token_id=tokenizer.eos_token_id | |
| ) | |
| # Decode and extract SQL | |
| result = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
| sql_query = result.split("### Response:")[-1].strip() | |
| return sql_query | |
# Gradio UI: two text inputs (database name, question) mapped through
# generate_sql to a single text output holding the generated SQL.
demo = gr.Interface(
    fn=generate_sql,
    inputs=[
        gr.Textbox(
            label="Database Name",
            placeholder="e.g., employees, sales, customers",
            value="employees",
        ),
        gr.Textbox(
            label="Question",
            placeholder="e.g., Show all employees with salary above 60000",
            lines=3,
        ),
    ],
    outputs=gr.Textbox(
        label="Generated SQL Query",
        lines=5,
    ),
    title="🤖 Text-to-SQL Generator",
    description="Fine-tuned Llama 3.2 3B model for SQL query generation using LoRA. Enter a database name and your question in natural language.",
    # NOTE(review): on Hugging Face Spaces, Gradio may pre-compute (cache) these
    # examples at startup by calling generate_sql on each one. If cold starts
    # are too slow, consider passing cache_examples=False. Confirm against the
    # installed Gradio version.
    examples=[
        ["employees", "Show all employees with salary above 60000"],
        ["sales", "Show me the top 5 products by total sales"],
        ["customers", "How many customers are from each country?"],
        ["orders", "Find all orders placed in the last 30 days"],
    ],
)

# Launched without additional parameters, as intended for Hugging Face Spaces.
demo.launch()