# ayurveda-SQL / app.py
# Author: joise-s-arakkal
# Last commit: "Update app.py" (e6c1b90, verified)
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import torch
import gradio as gr
import pandas as pd
import re
from pandasql import sqldf
base_model = "microsoft/phi-3-mini-4k-instruct"
adapter_model = "r929rrjq/phi3-mini-yoda-adapter"

# Dataset that the generated SQL is executed against (exposed as `data`).
data = pd.read_csv("AyurGenixAI_Dataset.csv")

# Choose the device once. Half precision is only worthwhile on a CUDA GPU;
# CPU-only hosts keep float32, which is what the original hard-coded.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model)

# Load base model without bitsandbytes (full weights, no quantization).
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    torch_dtype=dtype,
)

# Apply the LoRA adapter on top of the base weights, place the combined
# model on the chosen device, and switch to inference mode.
model = PeftModel.from_pretrained(model, adapter_model)
model.to(device)
model.eval()
def chat(prompt):
    """Turn a natural-language question into SQL and run it against ``data``.

    The prompt is wrapped in the Phi-3 chat markers and passed to the
    LoRA-adapted model. The first ``SELECT ...`` line of the model's reply
    is then executed with pandasql against the module-level ``data``
    DataFrame.

    Args:
        prompt: The user's question, as typed into the Gradio text box.

    Returns:
        pandas.DataFrame: The query result. If no SELECT statement is found,
        or the query fails, a one-column ``Error`` DataFrame with the error
        text is returned instead, so the UI always has a table to render.
    """
    structured_prompt = f"<|user|>{prompt}<|end|><|assistant|>"
    inputs = tokenizer(structured_prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=True,
            temperature=0.3,
            top_k=20,
            top_p=0.8,
        )

    # generate() returns the prompt tokens followed by the continuation.
    # Decode only the new tokens, so a "select" in the user's own text can
    # never be mistaken for the model's SQL.
    prompt_length = inputs["input_ids"].shape[1]
    my_sql_query = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    )
    print("Generated SQL:", my_sql_query)

    match = re.search(r"SELECT .*", my_sql_query, re.IGNORECASE)
    if match is None:
        # Report a clear reason instead of an AttributeError on None.group().
        message = "No SELECT statement found in model output"
        print("Error:", message)
        return pd.DataFrame({"Error": [message]})

    query = match.group(0).strip()
    print("Parsed SQL Query:", query)
    try:
        result_df = sqldf(query, {"data": data})
    except Exception as e:  # UI boundary: show SQL errors rather than crash
        print("Error:", e)
        result_df = pd.DataFrame({"Error": [str(e)]})  # Return error in table form
    return result_df
# Gradio front end: a single text box in, the SQL query result table out.
demo = gr.Interface(
    fn=chat,
    inputs="text",
    outputs=gr.Dataframe(label="Query Result"),
    title="Ayurvedic Chatbot",
)
demo.launch()