Spaces:
Runtime error
Runtime error
File size: 5,227 Bytes
33dd5ba a2ef1b5 8a5aaef 748fdd1 26009d1 190ad71 9a2d448 33dd5ba 318503e 33dd5ba 318503e 26009d1 a2ef1b5 f03b213 33dd5ba 9a2d448 33dd5ba 26009d1 33dd5ba f03b213 33dd5ba 26009d1 33dd5ba 26009d1 f03b213 33dd5ba f03b213 33dd5ba f03b213 c235810 a2b7239 52ab581 9a2d448 52ab581 7f1e8f9 52ab581 3e25e48 a2b7239 3e25e48 2f09f40 3e25e48 2f09f40 3e25e48 2f09f40 3e25e48 35181f7 3e25e48 2f09f40 3e25e48 2f09f40 35181f7 c17be4a 3e25e48 3e2be9a 190ad71 5a30bf1 3e2be9a 2f09f40 3e2be9a 3e25e48 7d12e27 2f09f40 8db29b9 5a30bf1 2f09f40 7d12e27 3e2be9a 26009d1 2f09f40 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
# --- Import-time bootstrap: load tokenizer, model, and (unused) dataset ---
import os, torch, gradio as gr, spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from datasets import load_dataset
# Fine-tuned Qwen2.5-3B checkpoint; overridable via the MODEL_ID env var.
MODEL_ID = os.getenv("MODEL_ID", "JDhruv14/Qwen2.5-3B-Gita-FT")
# Persona instruction injected as the system message of every conversation.
GITA_SYSTEM_PROMPT = """You are Lord Krishna—the serene, compassionate teacher of the Bhagavad Gita."""
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
# device_map="auto" shards/places the model automatically; use bf16 when a GPU
# is present, otherwise let transformers pick a dtype.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
    trust_remote_code=True,
)
# NOTE(review): `dataset` is downloaded here but never referenced anywhere in
# this file — presumably leftover from training/eval; confirm before removing.
dataset = load_dataset("JDhruv14/Bhagavad-Gita-QA")
# Ensure pad token exists (many chat models reuse EOS as PAD)
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token = tokenizer.eos_token
def _msgs_from_history(history, system_text):
msgs = []
if system_text:
msgs.append({"role": "system", "content": system_text})
if not history:
return msgs
if isinstance(history[0], dict) and "role" in history[0] and "content" in history[0]:
for m in history:
role, content = m.get("role"), m.get("content")
if role in ("user", "assistant", "system") and content:
msgs.append({"role": role, "content": content})
else:
for user, assistant in history:
if user:
msgs.append({"role": "user", "content": user})
if assistant:
msgs.append({"role": "assistant", "content": assistant})
return msgs
def _eos_ids(tok):
ids = set()
if tok.eos_token_id is not None:
if isinstance(tok.eos_token_id, (list, tuple)):
ids.update(tok.eos_token_id)
else:
ids.add(tok.eos_token_id)
try:
im_end = tok.convert_tokens_to_ids("<|im_end|>")
if im_end is not None and im_end != tok.unk_token_id:
ids.add(im_end)
except Exception:
pass
return list(ids)
def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
    """Generate a single assistant reply for *message* given prior *history*.

    Builds the chat prompt via the tokenizer's chat template, samples with the
    given temperature/top_p bounds, and returns only the newly generated text
    (prompt tokens stripped, special tokens removed).
    """
    conversation = _msgs_from_history(history, system_text)
    conversation.append({"role": "user", "content": message})
    prompt = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    # Fall back to EOS as PAD when the tokenizer defines no pad token.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id
    cfg_kwargs = dict(
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        max_new_tokens=int(max_new),
        min_new_tokens=int(min_new),
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        pad_token_id=pad_id,
    )
    # Only override eos ids when we actually found some (see _eos_ids).
    stop_ids = _eos_ids(tokenizer)
    if stop_ids:
        cfg_kwargs["eos_token_id"] = stop_ids
    with torch.no_grad():
        generated = model.generate(
            **inputs, generation_config=GenerationConfig(**cfg_kwargs)
        )
    # Keep only the completion: everything past the prompt's token length.
    prompt_len = inputs["input_ids"].shape[1]
    completion = generated[:, prompt_len:]
    return tokenizer.batch_decode(completion, skip_special_tokens=True)[0].strip()
@spaces.GPU()
def gradio_fn(message, history):
    """ZeroGPU entry point: answer *message* with the fixed Gita persona
    and preset sampling settings (temperature 0.7, top_p 0.95, up to 512
    new tokens)."""
    sampling = {
        "system_text": GITA_SYSTEM_PROMPT,
        "temperature": 0.7,
        "top_p": 0.95,
        "max_new": 512,
        "min_new": 0,
    }
    return chat_fn(message=message, history=history, **sampling)
# --- Gradio UI: narrow, left-aligned "glass" chat panel over a full-screen
# background image; page scrolling disabled, chat area scrolls internally. ---
with gr.Blocks(css="""
:root { --chat-w: 520px; }
html, body {
height: 100%;
overflow-y: hidden; /* no page scroll */
margin: 0;
}
/* Full-screen background image with a soft dark overlay */
body {
background:
linear-gradient(0deg, rgba(0,0,0,.28), rgba(0,0,0,.28)),
url("https://huggingface.co/spaces/JDhruv14/gita/resolve/main/bg.jpg") center / cover no-repeat fixed; /* <- change filename if needed */
}
/* Left-aligned, narrower chat panel */
.gradio-container {
max-width: var(--chat-w);
width: var(--chat-w);
margin-left: 16px;
margin-right: auto;
padding: 20px;
font-family: sans-serif;
position: relative;
/* optional glass effect for readability */
background: rgba(0,0,0,.30);
border-radius: 16px;
backdrop-filter: blur(6px);
}
.chatbot {
height: 480px !important;
overflow-y: auto;
}
@media (max-width: 720px){
:root { --chat-w: 92vw; }
.gradio-container { margin-left: 4vw; }
}
""") as demo:
    # Header banner rendered as raw HTML inside Markdown.
    gr.Markdown(
        """
<div style='text-align: center; padding: 10px;'>
<h1 style='font-size: 2.0em; margin-bottom: 0.2em;'><span style='color: #4F46E5;'>🪷 Sarathi.AI</span></h1>
<p style='font-size: 1.0em; color: #bbb;'>Gita’s Eternal Teachings, Guided by AI 🕉️</p>
</div>
""",
        elem_id="header"
    )
    # Chat widget wired to the GPU-decorated handler. type="messages" makes
    # Gradio pass history as role/content dicts (the dict branch of
    # _msgs_from_history).
    gr.ChatInterface(
        fn=gradio_fn,
        examples=[
            "Namaste!",
            "What is my duty?",
            "What is a Guna?",
            "What can I do to stop overthinking?"
        ],
        chatbot=gr.Chatbot(type="messages", elem_classes="chatbot"),
        type="messages",
    )
# Local launch entry point. Fix: removed the stray " |" scrape artifact that
# trailed `demo.launch()` and made the line a SyntaxError.
if __name__ == "__main__":
    demo.launch()