import os

import spaces  # import before torch so ZeroGPU can patch CUDA initialization
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
MODEL_ID = os.getenv("MODEL_ID", "JDhruv14/merged_model")
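# Set the MODEL_ID env var (e.g. in the Space settings) to swap checkpoints without code edits.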
# Load once (CPU until first call; device_map will move to GPU on first run)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
    trust_remote_code=True,
)
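
# Gradio's ChatInterface sends history as [user, assistant] pairs (the tuple
# format, default in Gradio 4.x); flatten them into role/content dicts for the chat template.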
def _msgs_from_history(history, system_text):
    msgs = []
    if system_text:
        msgs.append({"role": "system", "content": system_text})
    for user, assistant in history:
        if user:
            msgs.append({"role": "user", "content": user})
        if assistant:
            msgs.append({"role": "assistant", "content": assistant})
    return msgs
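
# Qwen-style chat templates close each turn with <|im_end|>; if the fine-tuned
# tokenizer's eos_token differs, accept both as stop tokens so generation halts cleanly.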
def _eos_ids(tok):
    ids = {tok.eos_token_id}
    im_end = tok.convert_tokens_to_ids("<|im_end|>")
    if im_end is not None:
        ids.add(im_end)
    ids.discard(None)  # in case the tokenizer defines no eos_token_id
    return list(ids)
@spaces.GPU(duration=120)  # REQUIRED for ZeroGPU; remove if using standard GPU hardware
def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
    msgs = _msgs_from_history(history, system_text) + [{"role": "user", "content": message}]
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    gen_cfg = GenerationConfig(
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        max_new_tokens=int(max_new),
        min_new_tokens=int(min_new),
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        eos_token_id=_eos_ids(tokenizer),
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        out = model.generate(**inputs, generation_config=gen_cfg)
    # Slice off the prompt so we show only the assistant reply.
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
    reply = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
    return reply

with gr.Blocks() as demo:
    gr.Markdown(
        "<h1 style='text-align:center'>Gita Assistant (Qwen2.5-3B Fine-tuned)</h1>"
        "<p style='text-align:center'>Ask in English / हिंदी / ગુજરાતી. The assistant cites verses when relevant.</p>"
    )
    system_box = gr.Textbox(
        value="Reply in the user’s language with 2–3 concise points (200–400 words); cite Gita verses when relevant.",
        label="System prompt",
    )
    temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
    max_new = gr.Slider(64, 1024, value=512, step=16, label="max_new_tokens")
    min_new = gr.Slider(0, 512, value=160, step=8, label="min_new_tokens")
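    # retry_btn / undo_btn / clear_btn are Gradio 4.x kwargs; Gradio 5 removed them
    # (its ChatInterface ships those buttons by default), so drop them on Gradio 5.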
    chat = gr.ChatInterface(
        fn=chat_fn,  # ChatInterface appends additional_inputs after (message, history)
        title=None,
        additional_inputs=[system_box, temperature, top_p, max_new, min_new],
        retry_btn="Regenerate",
        undo_btn="Undo Last",
        clear_btn="Clear",
    )
if __name__ == "__main__":
    demo.queue()  # queueing is recommended (and required for ZeroGPU concurrency)
    demo.launch()