File size: 5,705 Bytes
33dd5ba
 
8a5aaef
ccb51c9
26009d1
9a2d448
 
 
 
 
 
 
 
33dd5ba
 
 
 
318503e
33dd5ba
 
318503e
26009d1
f03b213
 
 
 
33dd5ba
 
 
 
9a2d448
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33dd5ba
26009d1
33dd5ba
f03b213
 
 
 
 
 
 
 
 
 
 
 
 
33dd5ba
26009d1
33dd5ba
 
 
 
26009d1
f03b213
 
33dd5ba
 
 
 
 
 
 
f03b213
33dd5ba
f03b213
 
 
c235810
 
 
 
 
 
 
 
 
a2b7239
52ab581
9a2d448
52ab581
 
 
9a2d448
52ab581
 
 
 
 
34bd51b
a2b7239
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b934f15
a2b7239
 
9a2d448
a2b7239
 
 
 
34bd51b
33dd5ba
be1224a
a2b7239
 
 
 
 
 
9a2d448
 
33dd5ba
a2b7239
 
 
 
 
 
 
 
 
26009d1
f03b213
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import os, torch, gradio as gr, spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# Model repo id; overridable via the MODEL_ID env var for easy swaps.
MODEL_ID = os.getenv("MODEL_ID", "JDhruv14/merged_model")

# --- System prompt (Gita persona) ---
# Sent verbatim as the "system" message on every turn; editing this string
# changes the assistant's persona and reply format.
GITA_SYSTEM_PROMPT = """You are KRISHNA.ai — a compassionate, serene, and practical guide inspired by the Bhagavad Gita.
Style: calm, clear, inclusive, and down-to-earth. Use everyday language, avoid jargon.
When fitting, quote a brief shloka with Chapter:Verse (e.g., 2:47) and give a one-line meaning. Do not over-quote.
Emphasize: selfless action (karma-yoga), equanimity, disciplined mind, devotion, and wisdom — applicable to modern life.
Be non-sectarian and respectful of all beliefs. If a topic is clinical/medical/legal, gently suggest professional help.
Prefer concise replies (5–10 sentences). Use short steps/bullets for “how-to” answers. End with a one-line “Essence:” summary when helpful."""

# Load once at import time (CPU until first call; device_map will move to GPU on first run)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    # bf16 only when CUDA is present; otherwise let transformers pick a dtype.
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
    trust_remote_code=True,
)

# Ensure pad token exists (many chat models reuse EOS as PAD)
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token = tokenizer.eos_token

def _msgs_from_history(history, system_text):
    msgs = []
    if system_text:
        msgs.append({"role": "system", "content": system_text})
    if not history:
        return msgs

    # Support both new "messages" format and legacy (user, assistant) tuples
    if isinstance(history[0], dict) and "role" in history[0] and "content" in history[0]:
        for m in history:
            role, content = m.get("role"), m.get("content")
            if role in ("user", "assistant", "system") and content:
                msgs.append({"role": role, "content": content})
    else:
        for user, assistant in history:
            if user:
                msgs.append({"role": "user", "content": user})
            if assistant:
                msgs.append({"role": "assistant", "content": assistant})
    return msgs

def _eos_ids(tok):
    # Support ints/lists and optional <|im_end|>
    ids = set()
    if tok.eos_token_id is not None:
        if isinstance(tok.eos_token_id, (list, tuple)):
            ids.update(tok.eos_token_id)
        else:
            ids.add(tok.eos_token_id)
    try:
        im_end = tok.convert_tokens_to_ids("<|im_end|>")
        if im_end is not None and im_end != tok.unk_token_id:
            ids.add(im_end)
    except Exception:
        pass
    return list(ids)

def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
    """Generate one assistant reply for *message* given the chat *history*.

    Renders the conversation through the tokenizer's chat template, samples a
    completion from the module-level ``model``, and returns only the newly
    generated text (prompt tokens stripped, special tokens skipped).
    """
    conversation = _msgs_from_history(history, system_text)
    conversation.append({"role": "user", "content": message})

    rendered = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer([rendered], return_tensors="pt").to(model.device)

    # Fall back to EOS as PAD when the tokenizer defines no pad token.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    cfg_kwargs = {
        "do_sample": True,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": int(max_new),
        "min_new_tokens": int(min_new),
        "repetition_penalty": 1.02,  # mild anti-repetition nudge
        "no_repeat_ngram_size": 3,
        "pad_token_id": pad_id,
    }
    stop_ids = _eos_ids(tokenizer)
    if stop_ids:
        cfg_kwargs["eos_token_id"] = stop_ids

    with torch.no_grad():
        generated = model.generate(
            **encoded, generation_config=GenerationConfig(**cfg_kwargs)
        )

    # Keep only the tokens produced after the prompt.
    prompt_len = encoded["input_ids"].shape[1]
    completion = generated[:, prompt_len:]
    return tokenizer.batch_decode(completion, skip_special_tokens=True)[0].strip()

@spaces.GPU()
def gradio_fn(message, history):
    """ChatInterface entry point: answer *message* with the Gita persona.

    The system prompt and sampling settings are fixed here rather than
    exposed in the UI.
    """
    settings = {
        "system_text": GITA_SYSTEM_PROMPT,
        "temperature": 0.7,
        "top_p": 0.95,
        "max_new": 512,
        "min_new": 0,
    }
    return chat_fn(message=message, history=history, **settings)

# --- UI layout: header, chat interface, and decorative corner images ---
with gr.Blocks(css="""
    .gradio-container {
        max-width: 600px;
        margin: auto;
        padding: 20px;
        font-family: sans-serif;
        position: relative;
        }
    .chatbot {
        height: 500px !important;
        overflow-y: auto;
        }
    .corner {
       position: fixed;
       bottom: 2px;
       z-index: 9999;
       pointer-events: none;
        } 
    #left { left: 2px; }
    #right { right: 2px; }
    .corner img {
       height: 500px;  /* fixed height */
       width: auto;    /* auto to keep aspect ratio */
        }
    
    """) as demo:
    # Centered title and tagline banner.
    gr.Markdown(
    """
        <div style='text-align: center; padding: 10px;'>
        <h1 style='font-size: 2.2em; margin-bottom: 0.2em;'><span style='color: #4F46E5;'>kRISHNA.ai</span></h1>
        <p style='font-size: 1.1em; color: #555;'>5000-Years of Ancient WISDOM with Modern AI ✨</p>
        </div>
    """,
    elem_id="header"
    )
    # Main chat widget; gradio_fn produces replies, history uses the
    # "messages" (role/content dict) format on both interface and chatbot.
    chat = gr.ChatInterface(
        fn=gradio_fn,
        examples=[
            "Hello!",
            "How can I overcome fear of failure?",
            "How do I forgive someone who hurt me deeply?",
            "What can I do to stop overthinking?"
        ],
        chatbot=gr.Chatbot(type="messages", elem_classes="chatbot"),
        type="messages",
    )
    # Decorative fixed-position corner images. NOTE(review): both img src
    # attributes are empty — presumably placeholders for asset URLs; confirm
    # before release. The f-string has no placeholders, so it is literal.
    gr.HTML(f"""
      <div id="left" class="corner">
        <img src="" alt="Arjun">
      </div>
      <div id="right" class="corner">
        <img src="" alt="Krishna">
      </div>
    """)

# Launch the app only when run as a script (Spaces also calls launch this way).
if __name__ == "__main__":
    demo.launch()