File size: 5,042 Bytes
33dd5ba
 
8a5aaef
06e919e
26009d1
33dd5ba
 
 
 
318503e
33dd5ba
 
318503e
26009d1
f03b213
 
 
 
33dd5ba
 
 
 
 
 
 
 
 
 
26009d1
33dd5ba
f03b213
 
 
 
 
 
 
 
 
 
 
 
 
 
33dd5ba
26009d1
33dd5ba
 
 
 
26009d1
f03b213
 
33dd5ba
 
 
 
 
 
 
f03b213
33dd5ba
f03b213
 
 
c235810
 
 
 
 
 
 
 
 
 
0647748
34bd51b
8759ca2
34bd51b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1dffd7
497537b
34bd51b
 
c4bee80
 
 
 
f03b213
c4bee80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8c693f
c4bee80
 
 
 
 
 
 
318503e
33dd5ba
be1224a
c4bee80
 
 
 
 
 
be1224a
c4bee80
33dd5ba
c4bee80
 
b961dd3
c4bee80
 
b961dd3
c4bee80
 
 
26009d1
f03b213
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import os, torch, gradio as gr, spaces
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# Hub repo id (or local path) of the chat model; override with the MODEL_ID env var.
MODEL_ID = os.getenv("MODEL_ID", "JDhruv14/hello")

# Load the tokenizer and model once, at import time, so every request reuses them.
# trust_remote_code=True executes Python code shipped inside the model repo —
# only point MODEL_ID at repos you trust.
# NOTE(review): device_map="auto" lets transformers/accelerate decide placement
# when the weights are loaded; whether weights sit on CPU until the first
# @spaces.GPU call depends on the Spaces runtime — confirm before relying on it.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    # bfloat16 when CUDA is visible at import time; otherwise let transformers pick.
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
    trust_remote_code=True,
)

# Some tokenizers define no pad token; reuse EOS so a valid pad id is always
# available (chat_fn also passes it as pad_token_id to generation).
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token = tokenizer.eos_token

def _msgs_from_history(history, system_text):
    msgs = []
    if system_text:
        msgs.append({"role": "system", "content": system_text})
    for user, assistant in history:
        if user:
            msgs.append({"role": "user", "content": user})
        if assistant:
            msgs.append({"role": "assistant", "content": assistant})
    return msgs

def _eos_ids(tok):
    # Support ints/lists and optional <|im_end|>
    ids = set()
    if tok.eos_token_id is not None:
        if isinstance(tok.eos_token_id, (list, tuple)):
            ids.update(tok.eos_token_id)
        else:
            ids.add(tok.eos_token_id)
    try:
        im_end = tok.convert_tokens_to_ids("<|im_end|>")
        if im_end is not None and im_end != tok.unk_token_id:
            ids.add(im_end)
    except Exception:
        pass
    # Fallback: if still empty, just skip setting eos_token_id in GenerationConfig
    return list(ids)

def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
    """Generate one assistant reply to ``message`` given the prior chat turns.

    Args:
        message: The new user message.
        history: Prior turns in a form ``_msgs_from_history`` accepts.
        system_text: Optional system prompt (omitted when empty).
        temperature: Sampling temperature. Values <= 0 switch to greedy
            decoding instead of failing inside ``generate``.
        top_p: Nucleus-sampling probability mass (used only when sampling).
        max_new: Maximum number of new tokens to generate.
        min_new: Minimum number of new tokens; clamped to ``[0, max_new]``.

    Returns:
        The decoded reply only (prompt tokens removed, special tokens skipped,
        surrounding whitespace stripped).
    """
    msgs = _msgs_from_history(history, system_text) + [{"role": "user", "content": message}]
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    # The rendered template already contains its special tokens (e.g. BOS);
    # letting the tokenizer add them again would duplicate them in the prompt.
    inputs = tokenizer([prompt], return_tensors="pt", add_special_tokens=False).to(model.device)

    temperature = float(temperature)
    max_new = int(max_new)
    # A minimum above the maximum is contradictory; cap it (and never go negative).
    min_new = max(0, min(int(min_new), max_new))
    # Sampling requires a strictly positive temperature; treat <= 0 as greedy.
    do_sample = temperature > 0

    gen_cfg_kwargs = dict(
        do_sample=do_sample,
        max_new_tokens=max_new,
        min_new_tokens=min_new,
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    )
    if do_sample:
        gen_cfg_kwargs.update(temperature=temperature, top_p=float(top_p))
    eos = _eos_ids(tokenizer)
    if eos:  # leave the config's default EOS untouched when none were found
        gen_cfg_kwargs["eos_token_id"] = eos

    gen_cfg = GenerationConfig(**gen_cfg_kwargs)

    with torch.no_grad():
        out = model.generate(**inputs, generation_config=gen_cfg)

    # Slice off the prompt so only the newly generated assistant reply is decoded.
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()

def infer_text(history, system_text=""):
    """Answer the newest user turn in ``history``.

    The last entry of ``history`` is the pending user turn; every earlier
    entry is passed to ``chat_fn`` as context. The last entry may be a
    ``(user, assistant)`` pair or a ``{"role": ..., "content": ...}`` dict.

    NOTE(review): the intended answer style is "reply in the user's language
    with 2–3 concise points (200–400 words); cite Gita verses when relevant".
    It is not enforced here: no system prompt is sent unless ``system_text``
    is given. Confirm whether the fine-tuned model needs one.

    Args:
        history: Conversation turns, newest last. Empty/None yields "".
        system_text: Optional system prompt forwarded to ``chat_fn``.

    Returns:
        The generated reply, or "" when there is no user message to answer.
    """
    if not history:
        return ""  # nothing to answer

    # Split out the newest user message and the prior turns.
    last_turn = history[-1]
    if isinstance(last_turn, dict):
        last_user_msg = last_turn.get("content")
    else:
        last_user_msg, _ = last_turn
    if not last_user_msg:
        return ""  # an empty prompt would just make the model ramble
    prior_history = history[:-1]

    # Generation defaults for the chat UI.
    return chat_fn(
        message=last_user_msg,
        history=prior_history,
        system_text=system_text,
        temperature=0.7,
        top_p=0.9,
        max_new=512,
        min_new=128,
    )

@spaces.GPU()
def gradio_fn(message, history):
    """ChatInterface callback: answer ``message`` in the context of ``history``.

    The new message is appended as a ``(message, None)`` turn (no assistant
    reply yet), and the whole conversation is handed to ``infer_text``.
    ``@spaces.GPU()`` presumably attaches a GPU for the duration of this call
    on Hugging Face Spaces — confirm against the ``spaces`` package docs.
    """
    pending_turn = (message, None)
    return infer_text(history + [pending_turn])

# Build the page: a centered 600px column holding a header, the chat UI, and
# two images pinned to the bottom corners. The CSS targets the "chatbot" class
# set on the Chatbot below and the .corner / #left / #right elements in the
# trailing HTML.
with gr.Blocks(css="""
    .gradio-container {
        max-width: 600px;
        margin: auto;
        padding: 20px;
        font-family: sans-serif;
        position: relative;
        }
    .chatbot {
        height: 500px !important;
        overflow-y: auto;
        }
    .corner {
       position: fixed;
       bottom: 2px;
       z-index: 9999;
       pointer-events: none;
        } 
    #left { left: 2px; }
    #right { right: 2px; }
    .corner img {
       height: 500px;  /* fixed height */
       width: auto;    /* auto to keep aspect ratio */
        }
    
    """) as demo:
    # Centered title/tagline header.
    gr.Markdown(
    """
        <div style='text-align: center; padding: 10px;'>
        <h1 style='font-size: 2.2em; margin-bottom: 0.2em;'>🤖 <span style='color: #4F46E5;'>kRISHNA.ai</span></h1>
        <p style='font-size: 1.1em; color: #555;'>5000-Years of Ancient WISDOM with Modern AI ✨</p>
        </div>
    """,
    elem_id="header"
    )
    # Chat UI; each submission is answered by gradio_fn(message, history).
    chat = gr.ChatInterface(
        fn=gradio_fn,
        examples=[
            "Hello!",
            "How can I overcome fear of failure?",
            "How do I forgive someone who hurt me deeply?",
            "What can I do to stop overthinking?"
        ],
        chatbot=gr.Chatbot(elem_classes="chatbot"),
        # NOTE(review): "compact" may not be a built-in theme in current Gradio
        # releases, and a theme on a ChatInterface nested inside gr.Blocks may
        # be ignored — confirm against the installed Gradio version.
        theme="compact",
    )
    # Decorative corner images (styled by the .corner CSS above).
    # NOTE(review): both src attributes are empty, so no image can load — fill
    # in the asset URLs. The f-string prefix is unused (no placeholders).
    gr.HTML(f"""
      <div id="left" class="corner">
        <img src="">
      </div>
      <div id="right" class="corner">
        <img src="">
      </div>
    """)

# Start the Gradio server when run as a script.
if __name__ == "__main__":
    demo.launch()