# Sarathi.AI / app.py
# Page metadata captured with this file: uploaded by JDhruv14,
# commit c235810 ("Update app.py", verified), file size 5.04 kB.
import os, torch, gradio as gr, spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
# Checkpoint to serve; can be overridden through the MODEL_ID environment variable.
MODEL_ID = os.environ.get("MODEL_ID", "JDhruv14/hello")

# Tokenizer and model are created once at import time so that every request
# reuses the same objects.
# NOTE(review): trust_remote_code=True allows Python code shipped with the
# checkpoint to run locally -- only point MODEL_ID at repositories you trust.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Generation needs a pad token; when the tokenizer defines none, reuse EOS
# (a common convention for chat models).
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token = tokenizer.eos_token

# bfloat16 is requested only when CUDA is visible; otherwise the dtype choice
# is left to the library ("auto").
# NOTE(review): presumably the `spaces` runtime defers real GPU placement until
# the first decorated call -- confirm against the deployment environment.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype="auto" if not torch.cuda.is_available() else torch.bfloat16,
)
def _msgs_from_history(history, system_text):
msgs = []
if system_text:
msgs.append({"role": "system", "content": system_text})
for user, assistant in history:
if user:
msgs.append({"role": "user", "content": user})
if assistant:
msgs.append({"role": "assistant", "content": assistant})
return msgs
def _eos_ids(tok):
# Support ints/lists and optional <|im_end|>
ids = set()
if tok.eos_token_id is not None:
if isinstance(tok.eos_token_id, (list, tuple)):
ids.update(tok.eos_token_id)
else:
ids.add(tok.eos_token_id)
try:
im_end = tok.convert_tokens_to_ids("<|im_end|>")
if im_end is not None and im_end != tok.unk_token_id:
ids.add(im_end)
except Exception:
pass
# Fallback: if still empty, just skip setting eos_token_id in GenerationConfig
return list(ids)
def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
    """Generate one assistant reply for ``message`` given the prior turns.

    Args:
        message: The new user message.
        history: Prior turns, forwarded to ``_msgs_from_history``.
        system_text: Optional system prompt (falsy to omit).
        temperature: Sampling temperature. Values ``<= 0`` select greedy
            decoding instead of sampling.
        top_p: Nucleus-sampling threshold (used only when sampling).
        max_new: Upper bound on newly generated tokens.
        min_new: Lower bound on newly generated tokens.

    Returns:
        The decoded reply with special tokens removed and surrounding
        whitespace stripped. The prompt is not included.
    """
    msgs = _msgs_from_history(history, system_text)
    msgs.append({"role": "user", "content": message})
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

    # The rendered template already carries the model's special tokens, so the
    # tokenizer must not prepend/append them a second time.
    inputs = tokenizer(
        [prompt],
        return_tensors="pt",
        add_special_tokens=False,
    ).to(model.device)

    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    gen_cfg_kwargs = dict(
        max_new_tokens=int(max_new),
        min_new_tokens=int(min_new),
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        pad_token_id=pad_id,
    )

    temperature = float(temperature)
    if temperature > 0.0:
        gen_cfg_kwargs.update(
            do_sample=True,
            temperature=temperature,
            top_p=float(top_p),
        )
    else:
        # Sampling requires a strictly positive temperature; treat zero or
        # negative values as a request for deterministic (greedy) decoding.
        gen_cfg_kwargs["do_sample"] = False

    # Only set eos_token_id when at least one stop id is known.
    eos = _eos_ids(tokenizer)
    if eos:
        gen_cfg_kwargs["eos_token_id"] = eos

    gen_cfg = GenerationConfig(**gen_cfg_kwargs)
    with torch.no_grad():
        out = model.generate(**inputs, generation_config=gen_cfg)

    # Slice off the prompt so only the newly generated tokens are decoded.
    prompt_len = inputs["input_ids"].shape[1]
    reply = tokenizer.batch_decode(out[:, prompt_len:], skip_special_tokens=True)[0]
    return reply.strip()
def infer_text(history, system_text=""):
    """Answer the most recent user turn of ``history`` with fixed settings.

    Args:
        history: Sequence of ``(user, assistant)`` pairs; the last pair holds
            the user message to answer (its assistant slot is ignored).
        system_text: Optional system prompt forwarded to ``chat_fn``.
            No system prompt is sent when this is left empty.

    Returns:
        The generated reply, or ``""`` when ``history`` is empty.
    """
    # Nothing to answer without at least one turn.
    if not history:
        return ""

    newest_turn = history[-1]
    user_text, _ = newest_turn

    # Everything before the newest turn is context; the decoding settings are
    # fixed defaults for this entry point.
    return chat_fn(
        user_text,
        history[:-1],
        system_text,
        temperature=0.7,
        top_p=0.9,
        max_new=512,
        min_new=128,
    )
@spaces.GPU()
def gradio_fn(message, history):
    """Chat callback: append the new user message to ``history`` and reply.

    The message is added as a ``(message, None)`` pair (no assistant answer
    yet) and the combined history is handed to ``infer_text``.
    """
    turns = history + [(message, None)]
    return infer_text(turns)
# Page layout: a centered header, the chat widget, and two fixed corner image
# slots. The CSS below sizes the container/chat area and pins the corners.
with gr.Blocks(css="""
.gradio-container {
max-width: 600px;
margin: auto;
padding: 20px;
font-family: sans-serif;
position: relative;
}
.chatbot {
height: 500px !important;
overflow-y: auto;
}
.corner {
position: fixed;
bottom: 2px;
z-index: 9999;
pointer-events: none;
}
#left { left: 2px; }
#right { right: 2px; }
.corner img {
height: 500px; /* fixed height */
width: auto; /* auto to keep aspect ratio */
}
""") as demo:
    # Title banner.
    gr.Markdown(
        """
<div style='text-align: center; padding: 10px;'>
<h1 style='font-size: 2.2em; margin-bottom: 0.2em;'>🤖 <span style='color: #4F46E5;'>kRISHNA.ai</span></h1>
<p style='font-size: 1.1em; color: #555;'>5000-Years of Ancient WISDOM with Modern AI ✨</p>
</div>
""",
        elem_id="header",
    )

    # Chat widget wired to the GPU-backed callback.
    # NOTE(review): confirm that "compact" is a theme name the installed
    # Gradio version recognises.
    chat = gr.ChatInterface(
        fn=gradio_fn,
        chatbot=gr.Chatbot(elem_classes="chatbot"),
        examples=[
            "Hello!",
            "How can I overcome fear of failure?",
            "How do I forgive someone who hurt me deeply?",
            "What can I do to stop overthinking?",
        ],
        theme="compact",
    )

    # Decorative corner images; both <img> sources are empty in this file.
    gr.HTML("""
<div id="left" class="corner">
<img src="">
</div>
<div id="right" class="corner">
<img src="">
</div>
""")

# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()