"""Gradio chat demo for a Bhagavad Gita fine-tune of Qwen2.5-3B.

Intended for Hugging Face Spaces; ``@spaces.GPU`` targets ZeroGPU hardware
and can be removed on a standard GPU Space.
"""

import os

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

MODEL_ID = os.getenv("MODEL_ID", "JDhruv14/Gita-FT-v2-Qwen2.5-3B")

# Load tokenizer and weights once at import time so every request reuses them.
# device_map="auto" lets transformers/accelerate pick the placement.
# NOTE(review): on ZeroGPU the `spaces` package is what makes CUDA usable inside
# @spaces.GPU calls — confirm placement behavior for your hardware tier.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
    trust_remote_code=True,
)


def _msgs_from_history(history, system_text):
    """Convert Gradio chat history into a chat-template message list.

    Args:
        history: Prior turns, either as ``(user, assistant)`` pairs (Gradio
            "tuples" format) or as ``{"role": ..., "content": ...}`` dicts
            (Gradio "messages" format). ``None`` is treated as no history.
        system_text: Optional system prompt; omitted when empty.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts. Empty turns are
        skipped; in the dict format, only string content from user/assistant
        roles is kept.
    """
    msgs = []
    if system_text:
        msgs.append({"role": "system", "content": system_text})
    for turn in history or []:
        if isinstance(turn, dict):
            role, content = turn.get("role"), turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                msgs.append({"role": role, "content": content})
            continue
        user, assistant = turn
        if user:
            msgs.append({"role": "user", "content": user})
        if assistant:
            msgs.append({"role": "assistant", "content": assistant})
    return msgs


def _eos_ids(tok):
    """Return the token ids that should stop generation, or ``None`` if none.

    Includes the tokenizer's EOS id (when set) and the ChatML ``<|im_end|>``
    turn terminator (when the vocabulary defines it).
    """
    ids = set()
    if tok.eos_token_id is not None:
        ids.add(tok.eos_token_id)
    im_end = tok.convert_tokens_to_ids("<|im_end|>")
    # A token missing from the vocab comes back as None or as the unk id,
    # depending on the tokenizer; neither is a real stop token.
    if im_end is not None and im_end != tok.unk_token_id:
        ids.add(im_end)
    return sorted(ids) or None


@spaces.GPU(duration=120)  # REQUIRED for ZeroGPU; remove if using standard GPU hardware
def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
    """Generate one sampled assistant reply to ``message``.

    Args:
        message: The new user message.
        history: Prior turns, in any format accepted by ``_msgs_from_history``.
        system_text: System prompt; may be empty.
        temperature: Sampling temperature (cast to float).
        top_p: Nucleus-sampling threshold (cast to float).
        max_new: Maximum number of generated tokens (cast to int).
        min_new: Minimum number of generated tokens (cast to int). It is
            clamped to ``max_new`` so that any slider combination is valid.

    Returns:
        The decoded reply, with special tokens removed and whitespace stripped.
    """
    msgs = _msgs_from_history(history, system_text)
    msgs.append({"role": "user", "content": message})
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    max_new = int(max_new)
    min_new = min(int(min_new), max_new)
    gen_cfg = GenerationConfig(
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        max_new_tokens=max_new,
        min_new_tokens=min_new,
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        eos_token_id=_eos_ids(tokenizer),
        pad_token_id=tokenizer.eos_token_id,
    )
    with torch.no_grad():
        out = model.generate(**inputs, generation_config=gen_cfg)
    # Slice off the prompt tokens so only the assistant's reply is decoded.
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()


with gr.Blocks() as demo:
    gr.Markdown(
        "# Gita Assistant (Qwen2.5-3B Fine-tuned)\n\n"
        "Ask in English / हिंदी / ગુજરાતી. The assistant cites verses when relevant."
    )
    system_box = gr.Textbox(
        value="Reply in the user’s language with 2–3 concise points (200–400 words); cite Gita verses when relevant.",
        label="System prompt",
    )
    temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="temperature")
    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
    max_new = gr.Slider(64, 1024, value=512, step=16, label="max_new_tokens")
    min_new = gr.Slider(0, 512, value=160, step=8, label="min_new_tokens")

    # Gradio calls fn(message, history, *additional_inputs) with each component's
    # *current* value, so chat_fn is passed directly. A two-argument lambda
    # reading `component.value` would fail with TypeError and, even if it did
    # not, would only ever see the initial defaults.
    chat = gr.ChatInterface(
        fn=chat_fn,
        title=None,
        additional_inputs=[system_box, temperature, top_p, max_new, min_new],
        retry_btn="Regenerate",
        undo_btn="Undo Last",
        clear_btn="Clear",
    )

if __name__ == "__main__":
    # ChatInterface has no `queue` keyword; enable the request queue on the
    # app instead (recommended for ZeroGPU concurrency).
    demo.queue().launch()