Spaces:
Runtime error
Runtime error
File size: 4,300 Bytes
33dd5ba 8a5aaef 33dd5ba 26009d1 33dd5ba 318503e 33dd5ba 318503e 26009d1 f03b213 33dd5ba 26009d1 33dd5ba f03b213 33dd5ba 26009d1 33dd5ba 26009d1 f03b213 33dd5ba f03b213 33dd5ba f03b213 33dd5ba 26009d1 33dd5ba 0d61b36 f03b213 33dd5ba e8c693f 33dd5ba 318503e f03b213 33dd5ba 318503e 33dd5ba e8c693f 33dd5ba be1224a f03b213 be1224a 33dd5ba 26009d1 f03b213 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import os, torch, gradio as gr, spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
MODEL_ID = os.getenv("MODEL_ID", "JDhruv14/merged_model")

# The tokenizer and model are loaded once, at import time, and shared by every
# request. device_map="auto" asks transformers to pick device placement when the
# weights are loaded. NOTE(review): an earlier note claimed the weights stay on
# CPU until the first GPU call; confirm that on the target (ZeroGPU) hardware.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
    # bfloat16 when CUDA is visible, otherwise let transformers decide.
    torch_dtype=(torch.bfloat16 if torch.cuda.is_available() else "auto"),
)

# Many chat models define no pad token; reuse EOS so generate() has one to pad with.
if tokenizer.eos_token_id is not None and tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
def _content_text(content):
    """Return the plain text of one chat-history entry, or "" if it carries no text.

    Strings are returned unchanged. A list/tuple of OpenAI-style content parts
    (dicts with ``"type": "text"`` and a ``"text"`` value) is joined with newlines.
    Anything else, such as an uploaded-file tuple like ``("image.png",)`` or
    ``None``, yields "" because a text-only chat template cannot render it.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, (list, tuple)):
        texts = [
            part.get("text") or ""
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        ]
        return "\n".join(text for text in texts if text)
    return ""


def _msgs_from_history(history, system_text):
    """Convert Gradio chat history into a list of chat-template messages.

    Both history formats Gradio can pass to a ChatInterface callback are accepted:

    * "tuples": a sequence of ``(user, assistant)`` pairs, where either side may
      be ``None`` (e.g. a turn whose reply is still pending);
    * "messages": a sequence of dicts with ``"role"`` and ``"content"`` keys
      (extra keys such as ``"metadata"`` are ignored).

    Args:
        history: Prior turns in either format; ``None`` is treated as empty.
        system_text: Optional system prompt, prepended when non-empty.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts. Only ``user`` and
        ``assistant`` turns with non-empty text content are kept.
    """
    msgs = []
    if system_text:
        msgs.append({"role": "system", "content": system_text})
    for turn in history or []:
        if isinstance(turn, dict):
            # "messages" format: one dict per message.
            entries = [(turn.get("role"), turn.get("content"))]
        else:
            # "tuples" format: one (user, assistant) pair per exchange.
            user, assistant = turn
            entries = [("user", user), ("assistant", assistant)]
        for role, content in entries:
            text = _content_text(content)
            if role in ("user", "assistant") and text:
                msgs.append({"role": role, "content": text})
    return msgs
def _eos_ids(tok):
    """Collect the token ids that should end generation.

    Combines ``tok.eos_token_id`` (a single id or a list/tuple of ids) with the
    id of ``<|im_end|>`` when the tokenizer resolves it to something other than
    its unknown-token id. Looking up ``<|im_end|>`` is best-effort: any failure
    there is ignored.

    Returns:
        A list of unique ids, or an empty list when none were found. Callers then
        leave ``eos_token_id`` unset in the GenerationConfig.
    """
    stop_ids = set()

    configured = tok.eos_token_id
    if configured is not None:
        stop_ids.update(configured if isinstance(configured, (list, tuple)) else [configured])

    try:
        im_end_id = tok.convert_tokens_to_ids("<|im_end|>")
        if im_end_id is not None and im_end_id != tok.unk_token_id:
            stop_ids.add(im_end_id)
    except Exception:
        # Deliberate best-effort: not every tokenizer can resolve this token.
        pass

    return list(stop_ids)
def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
    """Generate one assistant reply to ``message`` given the prior conversation.

    Args:
        message: The new user message text.
        history: Prior turns as delivered by Gradio; converted by ``_msgs_from_history``.
        system_text: Optional system prompt; omitted when empty.
        temperature: Sampling temperature (cast to float).
        top_p: Nucleus-sampling probability mass (cast to float).
        max_new: Maximum number of new tokens to generate (cast to int).
        min_new: Minimum number of new tokens to generate (cast to int).

    Returns:
        The decoded reply, containing only the newly generated tokens, with special
        tokens removed and surrounding whitespace stripped.
    """
    msgs = _msgs_from_history(history, system_text) + [{"role": "user", "content": message}]
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    # The rendered template already contains the special tokens the model expects.
    # Tokenizing it with add_special_tokens=True (the default) could prepend a
    # second BOS on tokenizers that add one, so disable it here.
    inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).to(model.device)

    eos = _eos_ids(tokenizer)
    gen_cfg_kwargs = dict(
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        max_new_tokens=int(max_new),
        min_new_tokens=int(min_new),
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    )
    # Only override the stop ids when some were found; otherwise keep the model's defaults.
    if eos:
        gen_cfg_kwargs["eos_token_id"] = eos
    gen_cfg = GenerationConfig(**gen_cfg_kwargs)

    with torch.no_grad():
        out = model.generate(**inputs, generation_config=gen_cfg)

    # generate() returns prompt + continuation; slice off the prompt so only the
    # assistant reply is shown.
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
    reply = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
    return reply
# ChatInterface callback. On ZeroGPU hardware, @spaces.GPU attaches a GPU only
# while the decorated function runs; the `spaces` docs describe it as having no
# effect on other hardware.
@spaces.GPU()
def gradio_fn(message, history, system_text, temperature, top_p, max_new, min_new):
    """Forward one chat turn and its generation settings to ``chat_fn``."""
    return chat_fn(
        message,
        history,
        system_text=system_text,
        temperature=temperature,
        top_p=top_p,
        max_new=max_new,
        min_new=min_new,
    )
with gr.Blocks() as demo:
    # Page header: title plus a note on the supported languages.
    gr.Markdown(
        "<h1 style='text-align:center'>Gita Assistant (Qwen2.5-3B Fine-tuned)</h1>"
        "<p style='text-align:center'>Ask in English / हिंदी / ગુજરાતી. The assistant cites verses when relevant.</p>"
    )

    # Extra controls. ChatInterface passes their values to gradio_fn, in this
    # order, after (message, history).
    system_box = gr.Textbox(
        label="System prompt",
        value="Reply in the user’s language with 2–3 concise points (200–400 words); cite Gita verses when relevant.",
    )
    temperature = gr.Slider(minimum=0.1, maximum=1.2, value=0.7, step=0.05, label="temperature")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="top_p")
    max_new = gr.Slider(minimum=64, maximum=1024, value=512, step=16, label="max_new_tokens")
    min_new = gr.Slider(minimum=0, maximum=512, value=160, step=8, label="min_new_tokens")

    chat = gr.ChatInterface(
        fn=gradio_fn,
        chatbot=gr.Chatbot(elem_classes="chatbot"),
        additional_inputs=[system_box, temperature, top_p, max_new, min_new],
        examples=[
            "Hello!",
            "How can I overcome fear of failure?",
            "How do I forgive someone who hurt me deeply?",
            "What can I do to stop overthinking?",
        ],
    )

# Start the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|