# Sarathi.AI / app.py
import os

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from peft import PeftModel

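# Assumed Space requirements (not pinned in this file): torch, transformers, peft,
# gradio, and the `spaces` helper that ZeroGPU Spaces provide.
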
# ---- IDs (can override from Space Secrets) ----
BASE_ID = os.getenv("BASE_ID", "Qwen/Qwen2.5-3B-Instruct")
ADAPTER_ID = os.getenv("ADAPTER_ID", "JDhruv14/Gita-FT-v2-Qwen2.5-3B")
# ---- Load tokenizer & base model ----
tokenizer = AutoTokenizer.from_pretrained(BASE_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    BASE_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
    trust_remote_code=True,
)

# Apply LoRA adapter
model = PeftModel.from_pretrained(model, ADAPTER_ID)
model.eval()
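# Optional tweak (an assumption, not in the original app): for inference-only use,
# PEFT's merge_and_unload() folds the LoRA weights into the base model, which can
# shave a little latency per token at the cost of adapter flexibility.
# model = model.merge_and_unload()
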
def _eos_ids(tok):
    """Collect every token id that should stop generation (EOS plus Qwen's <|im_end|>)."""
    ids = {tok.eos_token_id}
    im_end = tok.convert_tokens_to_ids("<|im_end|>")
    if im_end is not None:
        ids.add(im_end)
    return list(ids)

def _format_history(history, system_text):
    """Convert Gradio tuple-style history into a chat-template message list."""
    msgs = []
    if system_text:
        msgs.append({"role": "system", "content": system_text})
    for user, assistant in history:
        if user:
            msgs.append({"role": "user", "content": user})
        if assistant:
            msgs.append({"role": "assistant", "content": assistant})
    return msgs

@spaces.GPU(duration=120) # keep for ZeroGPU; remove this decorator if using a normal GPU Space
def chat_fn(message, history, system_text, temperature, top_p, max_new_tokens, min_new_tokens):
    msgs = _format_history(history, system_text) + [{"role": "user", "content": message}]
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    gen_cfg = GenerationConfig(
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        max_new_tokens=int(max_new_tokens),
        min_new_tokens=int(min_new_tokens),
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        eos_token_id=_eos_ids(tokenizer),
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        outputs = model.generate(**inputs, generation_config=gen_cfg)
    # Show only the assistant reply (slice off the prompt tokens).
    new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
    reply = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
    return reply

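# A minimal streaming variant (a sketch, not used by the app below): same prompt
# construction, but generate() runs in a worker thread and partial text is yielded
# through transformers' TextIteratorStreamer so gr.ChatInterface can stream replies.
# On ZeroGPU it would also need the @spaces.GPU decorator; wire it in by passing
# fn=chat_fn_streaming to gr.ChatInterface instead of fn=chat_fn.
from threading import Thread
from transformers import TextIteratorStreamer

def chat_fn_streaming(message, history, system_text, temperature, top_p,
                      max_new_tokens, min_new_tokens):
    msgs = _format_history(history, system_text) + [{"role": "user", "content": message}]
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    Thread(
        target=model.generate,
        kwargs=dict(
            **inputs,
            streamer=streamer,
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            max_new_tokens=int(max_new_tokens),
            min_new_tokens=int(min_new_tokens),
            eos_token_id=_eos_ids(tokenizer),
            pad_token_id=tokenizer.eos_token_id,
        ),
    ).start()
    partial = ""
    for piece in streamer:
        partial += piece
        yield partial  # ChatInterface renders each yielded string as the growing reply
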
with gr.Blocks() as demo:
    gr.Markdown(
        "<h1 style='text-align:center'>Gita Assistant (Qwen2.5-3B + LoRA)</h1>"
        "<p style='text-align:center'>Ask in English / हिंदी / ગુજરાતી. The assistant cites verses when relevant.</p>"
    )
    system_box = gr.Textbox(
        value="Reply in the user’s language with 2–3 concrete points (200–400 words); cite Gita verses when relevant.",
        label="System prompt",
    )
    temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
    max_new = gr.Slider(64, 1024, value=512, step=16, label="max_new_tokens")
    min_new = gr.Slider(0, 512, value=160, step=8, label="min_new_tokens")
    gr.ChatInterface(
        fn=chat_fn,  # signature: (message, history, system_text, temperature, top_p, max_new_tokens, min_new_tokens)
        additional_inputs=[system_box, temperature, top_p, max_new, min_new],
        chatbot=gr.Chatbot(height=520, type="tuples"),  # tuple-style history matches _format_history; no behavior change
        examples=[
            ["How do I practice Nishkama Karma at work?", system_box.value, 0.7, 0.9, 512, 160],
            ["What does 3.19 teach about duty without attachment?", system_box.value, 0.7, 0.9, 512, 160],
            ["How to overcome fear of failure according to the Gita?", system_box.value, 0.7, 0.9, 512, 160],
        ],
    )

if __name__ == "__main__":
    demo.launch()
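# Local-run note (assumption): outside a ZeroGPU Space the @spaces.GPU decorator is
# a no-op, so `python app.py` should work on any machine with enough GPU/CPU memory.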