import os

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

MODEL_ID = os.getenv("MODEL_ID", "JDhruv14/fullnfinal")

# --- System prompt (Gita persona) ---
GITA_SYSTEM_PROMPT = """You are Krishna — a compassionate, serene, and practical guide inspired by the Bhagavad Gita."""

# Load once (CPU until first call; device_map will move the model to GPU on first run)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
    trust_remote_code=True,
)

# Ensure a pad token exists (many chat models reuse EOS as PAD)
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token = tokenizer.eos_token


def _msgs_from_history(history, system_text):
    """Convert Gradio chat history into chat-template messages."""
    msgs = []
    if system_text:
        msgs.append({"role": "system", "content": system_text})
    if not history:
        return msgs
    # Support both the new "messages" format and legacy (user, assistant) tuples
    if isinstance(history[0], dict) and "role" in history[0] and "content" in history[0]:
        for m in history:
            role, content = m.get("role"), m.get("content")
            if role in ("user", "assistant", "system") and content:
                msgs.append({"role": role, "content": content})
    else:
        for user, assistant in history:
            if user:
                msgs.append({"role": "user", "content": user})
            if assistant:
                msgs.append({"role": "assistant", "content": assistant})
    return msgs


def _eos_ids(tok):
    """Collect stop-token ids: int or list eos_token_id, plus <|im_end|> if present."""
    ids = set()
    if tok.eos_token_id is not None:
        if isinstance(tok.eos_token_id, (list, tuple)):
            ids.update(tok.eos_token_id)
        else:
            ids.add(tok.eos_token_id)
    try:
        im_end = tok.convert_tokens_to_ids("<|im_end|>")
        if im_end is not None and im_end != tok.unk_token_id:
            ids.add(im_end)
    except Exception:
        pass
    return list(ids)


def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
    msgs = _msgs_from_history(history, system_text) + [{"role": "user", "content": message}]
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    eos = _eos_ids(tokenizer)
    gen_cfg_kwargs = dict(
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        max_new_tokens=int(max_new),
        min_new_tokens=int(min_new),
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        pad_token_id=(
            tokenizer.pad_token_id
            if tokenizer.pad_token_id is not None
            else tokenizer.eos_token_id
        ),
    )
    if eos:
        gen_cfg_kwargs["eos_token_id"] = eos
    gen_cfg = GenerationConfig(**gen_cfg_kwargs)

    with torch.no_grad():
        out = model.generate(**inputs, generation_config=gen_cfg)
    # Keep only the newly generated tokens (drop the echoed prompt)
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
    reply = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
    return reply


@spaces.GPU()
def gradio_fn(message, history):
    # Inject the Gita system prompt here
    return chat_fn(
        message=message,
        history=history,
        system_text=GITA_SYSTEM_PROMPT,
        temperature=0.7,
        top_p=0.95,
        max_new=512,
        min_new=0,
    )


with gr.Blocks(css="""
.gradio-container {
    max-width: 600px;
    margin: auto;
    padding: 20px;
    font-family: sans-serif;
    position: relative;
}
.chatbot {
    height: 500px !important;
    overflow-y: auto;
}
.corner {
    position: fixed;
    bottom: 2px;
    z-index: 9999;
    pointer-events: none;
}
#left { left: 2px; }
#right { right: 2px; }
.corner img {
    height: 500px;  /* fixed height */
    width: auto;    /* auto to keep aspect ratio */
}
""") as demo:
    gr.Markdown(
        """

# kRISHNA.ai
### 5000 Years of Ancient Wisdom with Modern AI ✨

""", elem_id="header" ) chat = gr.ChatInterface( fn=gradio_fn, examples=[ "Hello!", "How can I overcome fear of failure?", "How do I forgive someone who hurt me deeply?", "What can I do to stop overthinking?" ], chatbot=gr.Chatbot(type="messages", elem_classes="chatbot"), type="messages", ) gr.HTML(f"""
    <div class="corner" id="right">
        <img src="arjun.png" alt="Arjun">
    </div>
""") if __name__ == "__main__": demo.launch()