Sarathi.AI / app.py
JDhruv14's picture
Update app.py
c17be4a verified
raw
history blame
4.91 kB
import os, torch, gradio as gr, spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
# Hub repository to load; can be overridden through the MODEL_ID environment variable.
MODEL_ID = os.getenv("MODEL_ID", "JDhruv14/fullnfinal")

# --- System prompt (Gita persona) ---
# This text is passed verbatim as the system message, so it must be real UTF-8
# text (a proper em dash), not mis-decoded bytes.
GITA_SYSTEM_PROMPT = """You are Krishna — a compassionate, serene, and practical guider inspired by the Bhagavad Gita"""

# Load the tokenizer and model once at import time; device placement is
# delegated to device_map="auto".
# NOTE(review): trust_remote_code=True executes code shipped with the model
# repository — only point MODEL_ID at repositories you trust.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    # bfloat16 when CUDA is visible; otherwise let the library pick the dtype.
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
    trust_remote_code=True,
)

# Ensure a pad token exists (many chat models reuse EOS as PAD).
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token = tokenizer.eos_token
def _msgs_from_history(history, system_text):
msgs = []
if system_text:
msgs.append({"role": "system", "content": system_text})
if not history:
return msgs
# Support both new "messages" format and legacy (user, assistant) tuples
if isinstance(history[0], dict) and "role" in history[0] and "content" in history[0]:
for m in history:
role, content = m.get("role"), m.get("content")
if role in ("user", "assistant", "system") and content:
msgs.append({"role": role, "content": content})
else:
for user, assistant in history:
if user:
msgs.append({"role": "user", "content": user})
if assistant:
msgs.append({"role": "assistant", "content": assistant})
return msgs
def _eos_ids(tok):
# Support ints/lists and optional <|im_end|>
ids = set()
if tok.eos_token_id is not None:
if isinstance(tok.eos_token_id, (list, tuple)):
ids.update(tok.eos_token_id)
else:
ids.add(tok.eos_token_id)
try:
im_end = tok.convert_tokens_to_ids("<|im_end|>")
if im_end is not None and im_end != tok.unk_token_id:
ids.add(im_end)
except Exception:
pass
return list(ids)
def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
    """Generate one assistant reply to ``message`` given the prior conversation.

    The conversation (system prompt, history, new user message) is rendered
    with the tokenizer's chat template, passed through ``model.generate`` and
    only the newly generated tokens are decoded.

    Args:
        message: The new user message.
        history: Prior turns, in any format ``_msgs_from_history`` accepts.
        system_text: Optional system prompt placed first in the conversation.
        temperature: Sampling temperature. Values <= 0 switch to greedy
            decoding instead of being passed to the sampler.
        top_p: Nucleus-sampling probability mass (used only when sampling).
        max_new: Upper bound on generated tokens (clamped to at least 1).
        min_new: Lower bound on generated tokens (clamped to ``[0, max_new]``).

    Returns:
        The decoded reply with special tokens removed and surrounding
        whitespace stripped.
    """
    msgs = _msgs_from_history(history, system_text) + [{"role": "user", "content": message}]
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    # The rendered template already contains the special tokens it needs;
    # adding them again while tokenizing could duplicate e.g. a BOS token.
    inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).to(model.device)

    temperature = float(temperature)
    max_new_tokens = max(1, int(max_new))
    # A minimum above the maximum is contradictory; keep it within range.
    min_new_tokens = min(max(0, int(min_new)), max_new_tokens)

    gen_cfg_kwargs = dict(
        max_new_tokens=max_new_tokens,
        min_new_tokens=min_new_tokens,
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    )
    if temperature > 0:
        gen_cfg_kwargs.update(do_sample=True, temperature=temperature, top_p=float(top_p))
    else:
        # A non-positive temperature is not valid for sampling; decode greedily.
        gen_cfg_kwargs["do_sample"] = False

    eos = _eos_ids(tokenizer)
    if eos:
        gen_cfg_kwargs["eos_token_id"] = eos
    gen_cfg = GenerationConfig(**gen_cfg_kwargs)

    with torch.no_grad():
        out = model.generate(**inputs, generation_config=gen_cfg)

    # generate() returns prompt + continuation; slice off the prompt tokens.
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
    reply = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
    return reply
@spaces.GPU()
def gradio_fn(message, history):
    """Chat callback: answer ``message`` using the Gita persona and fixed decoding settings.

    Forwards to ``chat_fn`` with ``GITA_SYSTEM_PROMPT`` as the system text.
    NOTE(review): presumably ``spaces.GPU()`` requests a GPU for the duration
    of the call on Hugging Face Spaces — confirm against the ``spaces`` docs.
    """
    # Fixed decoding settings used for every request coming from the UI.
    decoding = {
        "temperature": 0.7,
        "top_p": 0.95,
        "max_new": 512,
        "min_new": 0,
    }
    return chat_fn(
        message=message,
        history=history,
        system_text=GITA_SYSTEM_PROMPT,  # inject the Gita system prompt here
        **decoding,
    )
# Page stylesheet: constrains the main container, fixes the chatbot height and
# pins the two decorative ".corner" images to the bottom corners of the viewport.
_PAGE_CSS = """
:root { --corner-h: min(65vh, 820px); } /* bigger images, size caps to viewport */
html, body { overflow-y: hidden; } /* no vertical scroll */
.gradio-container {
max-width: 600px;
margin: auto;
padding: 20px;
font-family: sans-serif;
position: relative;
}
.chatbot { height: 500px !important; overflow-y: auto; }
.corner {
position: fixed;
z-index: 9999;
pointer-events: none;
}
#left { left: 8px; bottom: 8px; } /* Arjuna pinned to bottom-left */
#right { right: 8px; bottom: 88px; } /* Krishna lifted ~80–90px */
.corner img {
height: var(--corner-h);
width: auto;
display: block;
}
/* safety on short screens: scale down a bit so nothing gets cramped */
@media (max-height: 740px) {
:root { --corner-h: 50vh; }
#right { bottom: 72px; }
}
"""

# Static markup for the two corner images styled by the "#left" / "#right" rules.
_CORNER_IMAGES_HTML = """
<div id="left" class="corner">
<img src="https://huggingface.co/spaces/JDhruv14/gita/resolve/main/arjuna.png" alt="Arjun">
</div>
<div id="right" class="corner">
<img src="https://huggingface.co/spaces/JDhruv14/gita/resolve/main/krishna.png" alt="Krishna">
</div>
"""

with gr.Blocks(css=_PAGE_CSS) as demo:
    # Placeholder statement kept exactly as found.
    # NOTE(review): presumably the chat components were elided here — confirm.
    ...
    gr.HTML(_CORNER_IMAGES_HTML)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()