# Sarathi.AI / app.py
# (Hugging Face Space file-viewer chrome preserved as a comment — the original
#  lines "JDhruv14's picture / Update app.py / e8c693f verified / raw /
#  history blame / 3.67 kB" were page scrape residue, not valid Python.)
import torch
# Disable TorchDynamo/torch.compile graph capture; run the model eagerly.
torch._dynamo.config.disable = True
from collections.abc import Iterator
from transformers import (
    Gemma3ForConditionalGeneration,  # NOTE(review): imported but unused in this file — confirm before removing
    TextIteratorStreamer,            # NOTE(review): imported but unused in this file — confirm before removing
    Gemma3Processor,
    Gemma3nForConditionalGeneration,
)
import gradio as gr
import os
import spaces
# Load environment variables
# Model repo id; overridable via the MODEL_3N_ID env var (defaults to the merged fine-tune).
model_3n_id = os.getenv("MODEL_3N_ID", "JDhruv14/merged_model")
# Load model and processor
# bfloat16 weights, automatic device placement; "eager" attention avoids
# flash/sdpa kernels (matches the dynamo-disabled, compile-free setup above).
model_3n = Gemma3nForConditionalGeneration.from_pretrained(
    model_3n_id,
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager"
)
# Processor handles chat templating and tokenization for the model above.
input_processor = Gemma3Processor.from_pretrained(model_3n_id)
def infer_text(messages, max_new_tokens=300, temperature=1.0, top_p=0.95, top_k=64, repetition_penalty=1.1):
    """Generate one assistant reply for a Gradio-style chat history.

    Args:
        messages: list of ``(user_text, assistant_text)`` tuples; either element
            may be falsy/None (e.g. the final turn is ``(message, None)``).
        max_new_tokens / temperature / top_p / top_k / repetition_penalty:
            standard sampling knobs forwarded to ``model_3n.generate``.

    Returns:
        The decoded, whitespace-stripped continuation produced by the model.
    """
    # Convert the (user, assistant) tuple history into the processor's
    # chat-template message format.
    chat_template = []
    for turn in messages:
        if turn[0]:
            chat_template.append({"role": "user", "content": [{"type": "text", "text": turn[0]}]})
        if turn[1]:
            chat_template.append({"role": "assistant", "content": [{"type": "text", "text": turn[1]}]})
    # BUGFIX: the original appended an empty assistant message here while ALSO
    # passing add_generation_prompt=True, injecting a spurious empty assistant
    # turn ahead of the generation prompt. add_generation_prompt alone is the
    # correct way to cue the model to respond.
    inputs = input_processor.apply_chat_template(
        chat_template,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model_3n.device)  # BUGFIX: move device only — token ids must remain integral, never cast to bfloat16
    with torch.no_grad():
        output_tokens = model_3n.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=True,
        )
    # Decode only the newly generated tokens (slice off the prompt).
    generated_text = input_processor.batch_decode(
        output_tokens[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True
    )[0]
    return generated_text.strip()
@spaces.GPU()
def gradio_fn(message, history):
    """Gradio chat callback: append the new user turn and run inference."""
    conversation = list(history)
    conversation.append((message, None))
    return infer_text(conversation)
# Build the Gradio UI: a centered chat column with fixed decorative images
# pinned to the bottom-left and bottom-right corners.
with gr.Blocks(css="""
.gradio-container {
max-width: 600px;
margin: auto;
padding: 20px;
font-family: sans-serif;
position: relative;
}
.chatbot {
height: 500px !important;
overflow-y: auto;
}
.corner {
position: fixed;
bottom: 2px;
z-index: 9999;
pointer-events: none;
}
#left { left: 2px; }
#right { right: 2px; }
.corner img {
height: 500px; /* fixed height */
width: auto; /* auto to keep aspect ratio */
}
""") as demo:
    # Page header (title + tagline).
    gr.Markdown(
    """
    <div style='text-align: center; padding: 10px;'>
    <h1 style='font-size: 2.2em; margin-bottom: 0.2em;'>🤖 <span style='color: #4F46E5;'>kRISHNA.ai</span></h1>
    <p style='font-size: 1.1em; color: #555;'>5000-Years of Ancient WISDOM with Modern AI ✨</p>
    </div>
    """,
    elem_id="header"
    )
    # Chat widget wired to the GPU-decorated inference callback.
    chat = gr.ChatInterface(
    fn=gradio_fn,
    examples=[
    "Hello!",
    "How can I overcome fear of failure?",
    "How do I forgive someone who hurt me deeply?",
    "What can I do to stop overthinking?"
    ],
    chatbot=gr.Chatbot(elem_classes="chatbot"),
    theme="compact",
    )
    # Corner artwork, positioned by the .corner/#left/#right CSS above.
    # NOTE(review): f-string has no placeholders — plain string would suffice.
    gr.HTML(f"""
    <div id="left" class="corner">
    <img src="https://huggingface.co/spaces/p2kalita/kRISHNA.ai/resolve/main/assets/Arjun.png" alt="Arjun">
    </div>
    <div id="right" class="corner">
    <img src="https://huggingface.co/spaces/p2kalita/kRISHNA.ai/resolve/main/assets/Krishna.png" alt="Krishna">
    </div>
    """)
# Launch the Gradio server only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()