Spaces:
Runtime error
Runtime error
File size: 3,668 Bytes
e8c693f 8a5aaef e8c693f 26009d1 e8c693f 318503e e8c693f 318503e e8c693f 26009d1 e8c693f 26009d1 e8c693f 26009d1 318503e e8c693f 26009d1 e8c693f 26009d1 e8c693f 0d61b36 e8c693f 318503e e8c693f 0d61b36 e8c693f 0d61b36 e8c693f 318503e e8c693f 26009d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import torch
torch._dynamo.config.disable = True
from collections.abc import Iterator
from transformers import (
Gemma3ForConditionalGeneration,
TextIteratorStreamer,
Gemma3Processor,
Gemma3nForConditionalGeneration,
)
import gradio as gr
import os
import spaces
# Checkpoint to serve. The MODEL_3N_ID environment variable overrides the
# default repository id.
model_3n_id = os.environ.get("MODEL_3N_ID", "JDhruv14/merged_model")

# Load the model once at import time so every request reuses the same weights.
# bfloat16 weights, automatic device placement, eager attention.
model_3n = Gemma3nForConditionalGeneration.from_pretrained(
    model_3n_id,
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager",
)

# Processor (chat template + tokenizer) loaded from the same checkpoint.
# NOTE(review): the model class is Gemma3n but the processor class is Gemma3 --
# confirm this pairing is intended for this checkpoint.
input_processor = Gemma3Processor.from_pretrained(model_3n_id)
def _history_to_chat(messages):
    """Convert UI chat history into the processor's chat-message format.

    Each entry of ``messages`` is either a ``(user_text, assistant_text)``
    pair or a ``{"role": ..., "content": ...}`` dict. Empty or ``None`` texts
    are skipped. Dict entries whose content is not a plain string are skipped
    as well, since only text turns are supported here.
    """
    chat = []

    def add(role, text):
        if text:
            chat.append({"role": role, "content": [{"type": "text", "text": text}]})

    for turn in messages:
        if isinstance(turn, dict):
            content = turn.get("content")
            if isinstance(content, str):
                add(turn.get("role"), content)
            continue
        add("user", turn[0])
        add("assistant", turn[1])
    return chat


def infer_text(messages, max_new_tokens=300, temperature=1.0, top_p=0.95, top_k=64, repetition_penalty=1.1):
    """Generate the assistant's next reply for a conversation.

    Args:
        messages: Conversation so far, oldest turn first. Entries may be
            ``(user_text, assistant_text)`` pairs (either side may be empty or
            ``None``) or ``{"role": ..., "content": str}`` dicts.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept for sampling.
        repetition_penalty: Penalty applied to already generated tokens.

    Returns:
        The newly generated text only (prompt removed), with special tokens
        dropped and surrounding whitespace stripped.

    Raises:
        ValueError: If ``messages`` contains no non-empty text turn.
    """
    chat_template = _history_to_chat(messages)
    if not chat_template:
        raise ValueError("messages must contain at least one non-empty text turn")

    # No trailing empty assistant message is appended here:
    # add_generation_prompt=True already opens the assistant turn, and an
    # extra empty one would put a closed, blank model turn into the prompt.
    inputs = input_processor.apply_chat_template(
        chat_template,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device=model_3n.device, dtype=torch.bfloat16)

    with torch.no_grad():
        output_tokens = model_3n.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=True,
        )

    # generate() returns prompt + continuation; decode only the continuation.
    prompt_length = inputs["input_ids"].shape[1]
    generated_text = input_processor.batch_decode(
        output_tokens[:, prompt_length:], skip_special_tokens=True
    )[0]
    return generated_text.strip()
@spaces.GPU()
def gradio_fn(message, history):
    """Chat callback: answer ``message`` given the previous ``history``.

    The new user message is appended to the history as a turn that has no
    assistant reply yet, and the whole conversation is handed to
    ``infer_text``, whose result is returned unchanged.
    """
    pending_turn = (message, None)
    conversation = history + [pending_turn]
    return infer_text(conversation)
# Page layout: centered chat column with a header, plus two decorative images
# pinned to the bottom corners of the viewport (styled by the CSS below).
with gr.Blocks(css="""
.gradio-container {
max-width: 600px;
margin: auto;
padding: 20px;
font-family: sans-serif;
position: relative;
}
.chatbot {
height: 500px !important;
overflow-y: auto;
}
.corner {
position: fixed;
bottom: 2px;
z-index: 9999;
pointer-events: none;
}
#left { left: 2px; }
#right { right: 2px; }
.corner img {
height: 500px; /* fixed height */
width: auto; /* auto to keep aspect ratio */
}
""") as demo:
    # Header banner.
    gr.Markdown(
        """
<div style='text-align: center; padding: 10px;'>
<h1 style='font-size: 2.2em; margin-bottom: 0.2em;'>🤖 <span style='color: #4F46E5;'>kRISHNA.ai</span></h1>
<p style='font-size: 1.1em; color: #555;'>5000-Years of Ancient WISDOM with Modern AI ✨</p>
</div>
""",
        elem_id="header",
    )

    # Chat widget. No `theme` is passed: "compact" is not a built-in Gradio
    # theme name, so it could only trigger a theme lookup that cannot succeed.
    chat = gr.ChatInterface(
        fn=gradio_fn,
        examples=[
            "Hello!",
            "How can I overcome fear of failure?",
            "How do I forgive someone who hurt me deeply?",
            "What can I do to stop overthinking?",
        ],
        chatbot=gr.Chatbot(elem_classes="chatbot"),
    )

    # Decorative corner images; a plain string is enough because nothing is
    # interpolated into the markup.
    gr.HTML("""
<div id="left" class="corner">
<img src="https://huggingface.co/spaces/p2kalita/kRISHNA.ai/resolve/main/assets/Arjun.png" alt="Arjun">
</div>
<div id="right" class="corner">
<img src="https://huggingface.co/spaces/p2kalita/kRISHNA.ai/resolve/main/assets/Krishna.png" alt="Krishna">
</div>
""")

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|