import torch

# Disable TorchDynamo globally for this app (set before loading the model).
torch._dynamo.config.disable = True

from collections.abc import Iterator

from transformers import (
    AutoProcessor,
    Gemma3ForConditionalGeneration,
    TextIteratorStreamer,
    Gemma3Processor,
    Gemma3nForConditionalGeneration,
)
import gradio as gr
import os
import spaces

# Load environment variables
model_3n_id = os.getenv("MODEL_3N_ID", "JDhruv14/merged_model")

# Load model and processor
model_3n = Gemma3nForConditionalGeneration.from_pretrained(
    model_3n_id,
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",
)
# AutoProcessor instantiates the processor class declared by the checkpoint
# itself, instead of forcing the Gemma 3 processor class onto a Gemma 3n model.
input_processor = AutoProcessor.from_pretrained(model_3n_id)


def _text_turn(role, text):
    """Wrap plain text as a single-part chat message for the processor."""
    return {"role": role, "content": [{"type": "text", "text": text}]}


def _build_chat(messages):
    """Convert chat history into the processor's list-of-messages format.

    Each item of ``messages`` may be either:
      * a ``(user_text, assistant_text)`` pair (Gradio "tuples" history),
        where ``None`` marks a missing side, e.g. the pending reply; or
      * a ``{"role": ..., "content": ...}`` dict (Gradio "messages" history).
    """
    chat = []
    for turn in messages:
        if isinstance(turn, dict):
            role, content = turn.get("role"), turn.get("content")
            # Only plain-text user/assistant messages are forwarded; any other
            # content (files, components) is ignored.
            if role in ("user", "assistant") and isinstance(content, str):
                chat.append(_text_turn(role, content))
            continue
        user_text, assistant_text = turn[0], turn[1]
        # Compare against None rather than truthiness: dropping an empty
        # string would leave two consecutive messages with the same role,
        # which chat templates presumably reject or mis-render.
        if user_text is not None:
            chat.append(_text_turn("user", user_text))
        if assistant_text is not None:
            chat.append(_text_turn("assistant", assistant_text))
    return chat


def infer_text(messages, max_new_tokens=300, temperature=1.0, top_p=0.95, top_k=64, repetition_penalty=1.1):
    """Generate the assistant's reply to a chat history.

    Args:
        messages: Chat history, ending with the user message to answer.
            Items are ``(user, assistant)`` pairs or role/content dicts
            (see ``_build_chat``).
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        top_p: Nucleus-sampling probability mass (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).
        repetition_penalty: Penalty applied to already-generated tokens.

    Returns:
        The decoded reply, without special tokens and with surrounding
        whitespace stripped.
    """
    chat_template = _build_chat(messages)
    # No empty assistant message is appended: add_generation_prompt=True
    # already opens the model's turn, and an empty assistant message would be
    # rendered as an extra, already-closed empty turn in front of it.
    inputs = input_processor.apply_chat_template(
        chat_template,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model_3n.device)  # text-only inputs: integer ids/masks, no dtype cast needed

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k)
    else:
        # Sampling requires a strictly positive temperature; decode greedily instead.
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        output_tokens = model_3n.generate(**inputs, **gen_kwargs)

    # Drop the prompt tokens so only the newly generated reply is decoded.
    prompt_len = inputs["input_ids"].shape[1]
    generated_text = input_processor.batch_decode(
        output_tokens[:, prompt_len:], skip_special_tokens=True
    )[0]
    return generated_text.strip()


@spaces.GPU()
def gradio_fn(message, history):
    """ChatInterface callback: answer ``message`` given the prior ``history``."""
    # The pending turn is a (message, None) pair; _build_chat accepts it
    # alongside either tuple- or dict-formatted history items.
    return infer_text(list(history or []) + [(message, None)])


APP_CSS = """
.gradio-container {
    max-width: 600px;
    margin: auto;
    padding: 20px;
    font-family: sans-serif;
    position: relative;
}
.chatbot {
    height: 500px !important;
    overflow-y: auto;
}
.corner {
    position: fixed;
    bottom: 2px;
    z-index: 9999;
    pointer-events: none;
}
#left { left: 2px; }
#right { right: 2px; }
.corner img {
    height: 500px; /* fixed height */
    width: auto;   /* auto to keep aspect ratio */
}
"""

with gr.Blocks(css=APP_CSS) as demo:
    gr.Markdown(
        """

🤖 kRISHNA.ai

5000-Years of Ancient WISDOM with Modern AI ✨

""",
        elem_id="header",
    )

    chat = gr.ChatInterface(
        fn=gradio_fn,
        examples=[
            "Hello!",
            "How can I overcome fear of failure?",
            "How do I forgive someone who hurt me deeply?",
            "What can I do to stop overthinking?",
        ],
        chatbot=gr.Chatbot(elem_classes="chatbot"),
        theme="compact",
    )

    gr.HTML("""
Arjun
""")

if __name__ == "__main__":
    demo.launch()