# Hugging Face Space: Gita Assistant (Qwen2.5-3B-Instruct + LoRA adapter)
import os, torch, gradio as gr, spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from peft import PeftModel
# ---- IDs (can override from Space Secrets) ----
# Base instruct model and the LoRA adapter layered on top of it.
BASE_ID = os.getenv("BASE_ID", "Qwen/Qwen2.5-3B-Instruct")
ADAPTER_ID = os.getenv("ADAPTER_ID", "JDhruv14/Gita-FT-v2-Qwen2.5-3B")
# ---- Load tokenizer & base model ----
# trust_remote_code allows the hub repo to supply custom model/tokenizer code.
tokenizer = AutoTokenizer.from_pretrained(BASE_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    BASE_ID,
    device_map="auto",  # let accelerate place weights on available devices
    # bf16 when a CUDA GPU is present; otherwise fall back to the checkpoint default.
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
    trust_remote_code=True,
)
# Apply LoRA adapter
model = PeftModel.from_pretrained(model, ADAPTER_ID)
model.eval()  # inference mode: disables dropout and similar training-only layers
def _eos_ids(tok):
ids = {tok.eos_token_id}
im_end = tok.convert_tokens_to_ids("<|im_end|>")
if im_end is not None:
ids.add(im_end)
return list(ids)
def _format_history(history, system_text):
msgs = []
if system_text:
msgs.append({"role": "system", "content": system_text})
for user, assistant in history:
if user:
msgs.append({"role": "user", "content": user})
if assistant:
msgs.append({"role": "assistant", "content": assistant})
return msgs
@spaces.GPU(duration=120)  # keep for ZeroGPU; remove this decorator if using a normal GPU Space
def chat_fn(message, history, system_text, temperature, top_p, max_new_tokens, min_new_tokens):
    """Produce one assistant reply for the Gradio ChatInterface.

    Builds a chat-template prompt from the tuple-style history plus the new
    user message, samples a completion from the LoRA-adapted model, and
    returns only the newly generated text (prompt tokens sliced off).
    """
    conversation = _format_history(history, system_text)
    conversation.append({"role": "user", "content": message})
    prompt_text = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer([prompt_text], return_tensors="pt").to(model.device)
    sampling_cfg = GenerationConfig(
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        max_new_tokens=int(max_new_tokens),
        min_new_tokens=int(min_new_tokens),
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        eos_token_id=_eos_ids(tokenizer),
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        generated = model.generate(**encoded, generation_config=sampling_cfg)
    # Keep only the tokens produced after the prompt.
    prompt_len = encoded["input_ids"].shape[1]
    completions = tokenizer.batch_decode(
        generated[:, prompt_len:], skip_special_tokens=True
    )
    return completions[0].strip()
# ---- Gradio UI: Blocks layout wrapping a ChatInterface ----
with gr.Blocks() as demo:
    # Header shown above the chat widget.
    gr.Markdown(
        "<h1 style='text-align:center'>Gita Assistant (Qwen2.5-3B + LoRA)</h1>"
        "<p style='text-align:center'>Ask in English / हिंदी / ગુજરાતી. The assistant cites verses when relevant.</p>"
    )
    # Editable system prompt; forwarded to chat_fn as `system_text`.
    system_box = gr.Textbox(
        value="Reply in the user’s language with 2–3 concrete points (200–400 words); cite Gita verses when relevant.",
        label="System prompt",
    )
    # Sampling controls; passed to chat_fn in the order given in additional_inputs.
    temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
    max_new = gr.Slider(64, 1024, value=512, step=16, label="max_new_tokens")
    min_new = gr.Slider(0, 512, value=160, step=8, label="min_new_tokens")
    gr.ChatInterface(
        fn=chat_fn,  # def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new)
        additional_inputs=[system_box, temperature, top_p, max_new, min_new],
        chatbot=gr.Chatbot(height=520, type="tuples"),  # keep tuple history; no behavior change
        # Each example row: the user message followed by one value per additional input.
        examples=[
            ["How do I practice Nishkama Karma at work?", system_box.value, 0.7, 0.9, 512, 160],
            ["What does 3.19 teach about duty without attachment?", system_box.value, 0.7, 0.9, 512, 160],
            ["How to overcome fear of failure according to the Gita?", system_box.value, 0.7, 0.9, 512, 160],
        ],
    )

# Launch the server when run as a script (Spaces executes this module directly).
if __name__ == "__main__":
    demo.launch()