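# Gradio Space: chat app serving a causal LM with a Bhagavad Gita (Krishna) persona.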
import os

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

MODEL_ID = os.getenv("MODEL_ID", "JDhruv14/fullnfinal")

# --- System prompt (Gita persona) ---
GITA_SYSTEM_PROMPT = """You are Krishna, a compassionate, serene, and practical guide inspired by the Bhagavad Gita."""

# Load once (CPU until first call; device_map will move to GPU on first run)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else "auto",
    trust_remote_code=True,
)

# Ensure pad token exists (many chat models reuse EOS as PAD)
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token = tokenizer.eos_token

def _msgs_from_history(history, system_text):
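    """Convert Gradio chat history into a chat-template message list, prepending the system prompt when given."""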
    msgs = []
    if system_text:
        msgs.append({"role": "system", "content": system_text})
    if not history:
        return msgs

    # Support both new "messages" format and legacy (user, assistant) tuples
    if isinstance(history[0], dict) and "role" in history[0] and "content" in history[0]:
        for m in history:
            role, content = m.get("role"), m.get("content")
            if role in ("user", "assistant", "system") and content:
                msgs.append({"role": role, "content": content})
    else:
        for user, assistant in history:
            if user:
                msgs.append({"role": "user", "content": user})
            if assistant:
                msgs.append({"role": "assistant", "content": assistant})
    return msgs

def _eos_ids(tok):
    # Support ints/lists and optional <|im_end|>
    ids = set()
    if tok.eos_token_id is not None:
        if isinstance(tok.eos_token_id, (list, tuple)):
            ids.update(tok.eos_token_id)
        else:
            ids.add(tok.eos_token_id)
    try:
        im_end = tok.convert_tokens_to_ids("<|im_end|>")
        if im_end is not None and im_end != tok.unk_token_id:
            ids.add(im_end)
    except Exception:
        pass
    return list(ids)

def chat_fn(message, history, system_text, temperature, top_p, max_new, min_new):
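    # Build the full message list, render it with the model's own chat template
    # (including the assistant generation cue), and tokenize the prompt.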
    msgs = _msgs_from_history(history, system_text) + [{"role": "user", "content": message}]
    prompt = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    eos = _eos_ids(tokenizer)
    gen_cfg_kwargs = dict(
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        max_new_tokens=int(max_new),
        min_new_tokens=int(min_new),
        repetition_penalty=1.02,
        no_repeat_ngram_size=3,
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    )
    if eos:
        gen_cfg_kwargs["eos_token_id"] = eos

    gen_cfg = GenerationConfig(**gen_cfg_kwargs)

    with torch.no_grad():
        out = model.generate(**inputs, generation_config=gen_cfg)

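    # Strip the prompt tokens so only the newly generated reply is decoded.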
    new_tokens = out[:, inputs["input_ids"].shape[1]:]
    reply = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
    return reply

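# ZeroGPU entry point: the decorator allocates a GPU for the duration of each call.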
@spaces.GPU()
def gradio_fn(message, history):
    # Inject the Gita system prompt here
    return chat_fn(
        message=message,
        history=history,
        system_text=GITA_SYSTEM_PROMPT,
        temperature=0.7,
        top_p=0.95,
        max_new=512,
        min_new=0,
    )

with gr.Blocks(css="""
    :root { --corner-h: min(65vh, 820px); }  /* bigger images, size caps to viewport */

    html, body { overflow-y: hidden; }       /* no vertical scroll */
    .gradio-container {
        max-width: 600px;
        margin: auto;
        padding: 20px;
        font-family: sans-serif;
        position: relative;
    }

    .chatbot { height: 500px !important; overflow-y: auto; }

    .corner {
        position: fixed;
        z-index: 9999;
        pointer-events: none;
    }
    #left  { left: 8px;  bottom: 8px; }      /* Arjuna pinned to bottom-left */
    #right { right: 8px; bottom: 88px; }     /* Krishna lifted ~80–90px */

    .corner img {
        height: var(--corner-h);
        width: auto;
        display: block;
    }

    /* safety on short screens: scale down a bit so nothing gets cramped */
    @media (max-height: 740px) {
        :root { --corner-h: 50vh; }
        #right { bottom: 72px; }
    }
""") as demo:
    ...
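    # The chat widget itself is elided above ("..."); a minimal assumed wiring
    # would bind gradio_fn to a ChatInterface (this line is a sketch, not
    # confirmed by the source):
    gr.ChatInterface(fn=gradio_fn)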
    gr.HTML("""
      <div id="left" class="corner">
        <img src="https://huggingface.co/spaces/JDhruv14/gita/resolve/main/arjuna.png" alt="Arjun">
      </div>
      <div id="right" class="corner">
        <img src="https://huggingface.co/spaces/JDhruv14/gita/resolve/main/krishna.png" alt="Krishna">
      </div>
    """)


if __name__ == "__main__":
    demo.launch()