import os

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

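# Model repo to load; can be overridden via the MODEL_ID environment variable.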
MODEL_ID = os.getenv("MODEL_ID", "JDhruv14/merged_model")

# Load once (CPU until first call; device_map will move to GPU on first run)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
    trust_remote_code=True,
)

# Ensure pad token exists (many chat models reuse EOS as PAD)
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token = tokenizer.eos_token

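# Convert Gradio's tuple-style (user, assistant) history into chat-template
# messages, prepending the optional system prompt.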
def _msgs_from_history(history, system_text):
    msgs = []
    if system_text:
        msgs.append({"role": "system", "content": system_text})
    for user, assistant in history:
        if user:
            msgs.append({"role": "user", "content": user})
        if assistant:
            msgs.append({"role": "assistant", "content": assistant})
    return msgs

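# Collect every token id that should terminate generation; some chat models
# use ChatML's <|im_end|> in addition to (or instead of) the tokenizer's EOS.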
def _eos_ids(tok):
    # Support ints/lists and optional <|im_end|>
    ids = set()
    if tok.eos_token_id is not None:
        if isinstance(tok.eos_token_id, (list, tuple)):
            ids.update(tok.eos_token_id)
        else:
            ids.add(tok.eos_token_id)
    try:
        im_end = tok.convert_tokens_to_ids("<|im_end|>")
        if im_end is not None and im_end != tok.unk_token_id:
            ids.add(im_end)
    except Exception:
        pass
    # Fallback: if still empty, just skip setting eos_token_id in GenerationConfig
    return list(ids)

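# Core inference: render the chat template, sample with the requested
# settings, and return only the newly generated assistant text.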
def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
    msgs = _msgs_from_history(history, system_text) + [{"role": "user", "content": message}]
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    eos = _eos_ids(tokenizer)
    gen_cfg_kwargs = dict(
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        max_new_tokens=int(max_new),
        min_new_tokens=int(min_new),
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    )
    if eos:
        gen_cfg_kwargs["eos_token_id"] = eos

    gen_cfg = GenerationConfig(**gen_cfg_kwargs)

    with torch.no_grad():
        out = model.generate(**inputs, generation_config=gen_cfg)

    # slice off the prompt so we show only the assistant reply
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
    reply = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
    return reply

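# ZeroGPU entry point: the @spaces.GPU decorator attaches a GPU for the
# duration of each call on Hugging Face Spaces.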
@spaces.GPU()
def gradio_fn(message, history):
    # Delegate to chat_fn; the sampling defaults below are assumptions,
    # not values tuned for this model.
    return chat_fn(
        message,
        history,
        system_text="",
        temperature=0.7,
        top_p=0.9,
        max_new=512,
        min_new=1,
    )

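# UI: a centered chat panel, with decorative corner images pinned to the
# bottom-left and bottom-right of the viewport.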
with gr.Blocks(css="""
    .gradio-container {
        max-width: 600px;
        margin: auto;
        padding: 20px;
        font-family: sans-serif;
        position: relative;
    }
    .chatbot {
        height: 500px !important;
        overflow-y: auto;
    }
    .corner {
        position: fixed;
        bottom: 2px;
        z-index: 9999;
        pointer-events: none;
    }
    #left { left: 2px; }
    #right { right: 2px; }
    .corner img {
        height: 500px;  /* fixed height */
        width: auto;    /* keep aspect ratio */
    }
""") as demo:
    gr.Markdown(
        """
        <div style='text-align: center; padding: 10px;'>
            <h1 style='font-size: 2.2em; margin-bottom: 0.2em;'>🤖 <span style='color: #4F46E5;'>kRISHNA.ai</span></h1>
            <p style='font-size: 1.1em; color: #555;'>5000 Years of Ancient Wisdom with Modern AI ✨</p>
        </div>
        """,
        elem_id="header",
    )
    chat = gr.ChatInterface(
        fn=gradio_fn,
        examples=[
            "Hello!",
            "How can I overcome fear of failure?",
            "How do I forgive someone who hurt me deeply?",
            "What can I do to stop overthinking?"
        ],
        chatbot=gr.Chatbot(elem_classes="chatbot"),
        theme="compact",
    )
    gr.HTML("""
      <div id="left" class="corner">
        <img src="" alt="Arjun">
      </div>
      <div id="right" class="corner">
        <img src="" alt="Krishna">
      </div>
    """)


if __name__ == "__main__":
    demo.launch()