File size: 5,227 Bytes
33dd5ba
 
a2ef1b5
8a5aaef
748fdd1
26009d1
190ad71
9a2d448
33dd5ba
 
 
318503e
33dd5ba
 
318503e
26009d1
a2ef1b5
f03b213
 
 
 
33dd5ba
 
 
 
9a2d448
 
 
 
 
 
 
 
 
 
 
 
 
 
33dd5ba
26009d1
33dd5ba
f03b213
 
 
 
 
 
 
 
 
 
 
 
33dd5ba
26009d1
33dd5ba
 
 
 
26009d1
f03b213
 
33dd5ba
 
 
 
 
 
 
f03b213
33dd5ba
f03b213
 
 
c235810
 
 
 
 
 
 
 
 
a2b7239
52ab581
 
 
 
9a2d448
52ab581
 
 
7f1e8f9
52ab581
3e25e48
a2b7239
3e25e48
2f09f40
 
 
 
 
 
3e25e48
 
 
2f09f40
3e25e48
2f09f40
3e25e48
 
35181f7
3e25e48
 
 
 
 
2f09f40
3e25e48
 
 
 
2f09f40
 
 
 
 
 
 
35181f7
c17be4a
3e25e48
3e2be9a
 
190ad71
5a30bf1
 
3e2be9a
2f09f40
 
3e2be9a
 
3e25e48
7d12e27
2f09f40
8db29b9
 
5a30bf1
2f09f40
 
 
7d12e27
 
3e2be9a
26009d1
2f09f40
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os, torch, gradio as gr, spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from datasets import load_dataset

# Hugging Face repo of the fine-tuned chat model; overridable via the MODEL_ID env var.
MODEL_ID = os.getenv("MODEL_ID", "JDhruv14/Qwen2.5-3B-Gita-FT")

# Persona prompt injected as the system message of every conversation.
GITA_SYSTEM_PROMPT = "You are Lord Krishna—the serene, compassionate teacher of the Bhagavad Gita."

# bfloat16 when a CUDA device is visible; otherwise defer the dtype choice to transformers.
_load_dtype = torch.bfloat16 if torch.cuda.is_available() else "auto"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# Some chat tokenizers ship without a pad token; fall back to EOS so that
# padding-dependent code (and generate's pad_token_id) has a valid id.
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=_load_dtype,
)

# NOTE(review): `dataset` is not referenced anywhere else in the visible code;
# it only costs a download at startup. Confirm whether it is still needed.
dataset = load_dataset("JDhruv14/Bhagavad-Gita-QA")

def _msgs_from_history(history, system_text):
    msgs = []
    if system_text:
        msgs.append({"role": "system", "content": system_text})
    if not history:
        return msgs

    if isinstance(history[0], dict) and "role" in history[0] and "content" in history[0]:
        for m in history:
            role, content = m.get("role"), m.get("content")
            if role in ("user", "assistant", "system") and content:
                msgs.append({"role": role, "content": content})
    else:
        for user, assistant in history:
            if user:
                msgs.append({"role": "user", "content": user})
            if assistant:
                msgs.append({"role": "assistant", "content": assistant})
    return msgs

def _eos_ids(tok):
    ids = set()
    if tok.eos_token_id is not None:
        if isinstance(tok.eos_token_id, (list, tuple)):
            ids.update(tok.eos_token_id)
        else:
            ids.add(tok.eos_token_id)
    try:
        im_end = tok.convert_tokens_to_ids("<|im_end|>")
        if im_end is not None and im_end != tok.unk_token_id:
            ids.add(im_end)
    except Exception:
        pass
    return list(ids)

def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
    """Generate one assistant reply for ``message`` given the prior chat ``history``.

    Args:
        message: The new user turn (text).
        history: Prior turns, in either format accepted by ``_msgs_from_history``.
        system_text: System prompt prepended to the conversation (skipped if empty).
        temperature: Sampling temperature. Values <= 0 select greedy
            decoding instead of sampling.
        top_p: Nucleus-sampling probability mass (only used when sampling).
        max_new: Maximum number of new tokens to generate.
        min_new: Minimum number of new tokens to generate.

    Returns:
        The decoded continuation with special tokens removed and surrounding
        whitespace stripped.
    """
    msgs = _msgs_from_history(history, system_text) + [{"role": "user", "content": message}]
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    # The rendered chat template already contains the special tokens the
    # model expects; letting the tokenizer add its own (e.g. a BOS) on top
    # would duplicate them for templates that emit BOS themselves.
    inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).to(model.device)

    temperature = float(temperature)
    # transformers rejects a non-positive temperature when sampling; treat
    # it as a request for deterministic greedy decoding instead of crashing.
    do_sample = temperature > 0

    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    gen_cfg_kwargs = dict(
        do_sample=do_sample,
        max_new_tokens=int(max_new),
        min_new_tokens=int(min_new),
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        pad_token_id=pad_id,
    )
    if do_sample:
        gen_cfg_kwargs["temperature"] = temperature
        gen_cfg_kwargs["top_p"] = float(top_p)

    eos = _eos_ids(tokenizer)
    if eos:
        gen_cfg_kwargs["eos_token_id"] = eos

    gen_cfg = GenerationConfig(**gen_cfg_kwargs)

    with torch.no_grad():
        out = model.generate(**inputs, generation_config=gen_cfg)

    # Drop the prompt tokens so only the newly generated continuation is decoded.
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()

@spaces.GPU()
def gradio_fn(message, history):
    """Gradio ChatInterface callback: answer ``message`` in the Gita persona.

    Runs under ``spaces.GPU()`` and delegates to ``chat_fn`` with the fixed
    system prompt and the decoding settings below.
    """
    decoding = {
        "temperature": 0.7,
        "top_p": 0.95,
        "max_new": 512,
        "min_new": 0,
    }
    return chat_fn(
        message=message,
        history=history,
        system_text=GITA_SYSTEM_PROMPT,
        **decoding,
    )

# Page styling for the Blocks app.
# NOTE(review): `overflow-y: hidden` on html/body disables page scrolling, so
# on viewports shorter than the panel the lower controls may be unreachable —
# confirm this is intended.
_APP_CSS = """
    :root { --chat-w: 520px; }
    html, body {
        height: 100%;
        overflow-y: hidden;               /* no page scroll */
        margin: 0;
    }
    /* Full-screen background image with a soft dark overlay */
    body {
        background:
          linear-gradient(0deg, rgba(0,0,0,.28), rgba(0,0,0,.28)),
          url("https://huggingface.co/spaces/JDhruv14/gita/resolve/main/bg.jpg") center / cover no-repeat fixed;  /* <- change filename if needed */
    }
    /* Left-aligned, narrower chat panel */
    .gradio-container {
        max-width: var(--chat-w);
        width: var(--chat-w);
        margin-left: 16px;
        margin-right: auto;
        padding: 20px;
        font-family: sans-serif;
        position: relative;
        /* optional glass effect for readability */
        background: rgba(0,0,0,.30);
        border-radius: 16px;
        backdrop-filter: blur(6px);
    }
    .chatbot {
        height: 480px !important;
        overflow-y: auto;
    }
    @media (max-width: 720px){
        :root { --chat-w: 92vw; }
        .gradio-container { margin-left: 4vw; }
    }
"""

# Title banner rendered above the chat panel.
_HEADER_HTML = """
        <div style='text-align: center; padding: 10px;'>
            <h1 style='font-size: 2.0em; margin-bottom: 0.2em;'><span style='color: #4F46E5;'>🪷 Sarathi.AI</span></h1>
            <p style='font-size: 1.0em; color: #bbb;'>Gita’s Eternal Teachings, Guided by AI 🕉️</p>
        </div>
        """

# Clickable starter prompts shown under the chat box.
_EXAMPLE_PROMPTS = [
    "Namaste!",
    "What is my duty?",
    "What is a Guna?",
    "What can I do to stop overthinking?",
]

with gr.Blocks(css=_APP_CSS) as demo:
    gr.Markdown(_HEADER_HTML, elem_id="header")

    # The chatbot's "chatbot" class is what the .chatbot CSS rule targets;
    # both it and the interface use the role/content "messages" format.
    gr.ChatInterface(
        fn=gradio_fn,
        examples=_EXAMPLE_PROMPTS,
        chatbot=gr.Chatbot(type="messages", elem_classes="chatbot"),
        type="messages",
    )

if __name__ == "__main__":
    demo.launch()