Spaces:
Runtime error
Runtime error
import re

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "kirankotha/mistral7b-sql-model"

# Pick the weight dtype from the available hardware instead of forcing
# float16 everywhere: half precision is the natural choice on a GPU, but
# float16 support on CPU is incomplete in some PyTorch builds (e.g.
# "addmm_impl_cpu_ not implemented for 'Half'") and slow where present.
# bfloat16 keeps the same memory footprint as float16 on CPU.
_dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16

tok = AutoTokenizer.from_pretrained(model_id)
# NOTE(review): device_map="auto" relies on the `accelerate` package being
# installed — confirm it is listed in the Space's requirements.
mdl = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=_dtype,
    device_map="auto",
)

# The tokenizer may define no pad token; reuse EOS so generate() below always
# receives a valid pad_token_id.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
| def gen_sql(question, context): | |
| prompt = f"You are a text-to-SQL model.\n### Input:\n{question}\n### Context:\n{context}\n### Response:\n" | |
| ids = tok(prompt, return_tensors="pt").to(mdl.device) | |
| with torch.no_grad(): | |
| out = mdl.generate( | |
| **ids, | |
| max_new_tokens=128, | |
| do_sample=False, | |
| pad_token_id=tok.pad_token_id | |
| ) | |
| text = tok.decode(out[0], skip_special_tokens=True) | |
| m = re.findall(r"```(.*?)```", text, flags=re.DOTALL) | |
| return m[-1].strip() if m else text.strip() | |
# Minimal Gradio UI: two free-text inputs mapped positionally onto
# gen_sql(question, context) and one text output showing the generated SQL.
_EXAMPLES = [
    [
        "Which product has the highest price?",
        "CREATE TABLE products (id INTEGER, name TEXT, price REAL)",
    ],
]

demo = gr.Interface(
    gen_sql,
    inputs=[gr.Textbox(), gr.Textbox()],
    outputs=gr.Textbox(),
    examples=_EXAMPLES,
    title="Mistral-7B Text-to-SQL Demo",
    description="Fine-tuned Mistral-7B model for text-to-SQL generation.",
)
demo.launch()