File size: 3,668 Bytes
e8c693f
 
 
 
 
 
 
 
 
 
 
 
8a5aaef
e8c693f
 
26009d1
e8c693f
 
 
 
318503e
e8c693f
318503e
e8c693f
26009d1
e8c693f
 
 
 
 
 
 
 
26009d1
e8c693f
 
 
 
 
 
 
26009d1
318503e
e8c693f
 
 
 
 
 
 
 
 
26009d1
e8c693f
 
 
 
26009d1
e8c693f
 
 
 
0d61b36
e8c693f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318503e
e8c693f
 
0d61b36
e8c693f
 
 
 
0d61b36
e8c693f
 
318503e
e8c693f
 
 
 
 
 
 
 
 
26009d1
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import torch
# Disable TorchDynamo compilation — avoids graph-capture failures with this
# model stack (presumably on the hosted GPU runtime — TODO confirm).
torch._dynamo.config.disable = True
from collections.abc import Iterator
from transformers import (
    Gemma3ForConditionalGeneration,
    TextIteratorStreamer,
    Gemma3Processor,
    Gemma3nForConditionalGeneration,
)
import gradio as gr
import os
import spaces

# NOTE(review): Iterator, Gemma3ForConditionalGeneration and
# TextIteratorStreamer are imported but unused in the visible code.

# Load environment variables
# Model repo id; overridable via the MODEL_3N_ID environment variable.
model_3n_id = os.getenv("MODEL_3N_ID", "JDhruv14/merged_model")

# Load model and processor
# Loaded once at import time; bfloat16 weights, device placement delegated
# to accelerate ("auto"), eager attention (no flash/sdpa kernels).
model_3n = Gemma3nForConditionalGeneration.from_pretrained(
    model_3n_id,
    dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="eager"
)
input_processor = Gemma3Processor.from_pretrained(model_3n_id)

def infer_text(messages, max_new_tokens=300, temperature=1.0, top_p=0.95, top_k=64, repetition_penalty=1.1):
    """Generate one assistant reply for a tuple-based chat history.

    Args:
        messages: list of ``(user_text, assistant_text)`` pairs; either
            element may be falsy/None and is then skipped.
        max_new_tokens: generation budget for the new reply.
        temperature, top_p, top_k, repetition_penalty: sampling parameters
            forwarded unchanged to ``model_3n.generate``.

    Returns:
        The newly generated assistant text with surrounding whitespace
        stripped (prompt tokens are excluded from decoding).
    """
    # Convert the (user, assistant) tuple history into the chat-template
    # message schema expected by the processor.
    chat_template = []
    for turn in messages:
        if turn[0]:
            chat_template.append({"role": "user", "content": [{"type": "text", "text": turn[0]}]})
        if turn[1]:
            chat_template.append({"role": "assistant", "content": [{"type": "text", "text": turn[1]}]})
    # BUG FIX: the original appended an empty assistant message here while
    # also passing add_generation_prompt=True below. The generation prompt
    # already supplies the assistant prefix; an extra empty assistant turn
    # corrupts the prompt for chat templates, so it is removed.

    inputs = input_processor.apply_chat_template(
        chat_template,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device=model_3n.device, dtype=torch.bfloat16)
    # NOTE(review): the dtype cast above only affects floating-point tensors;
    # integer input_ids/attention_mask keep their dtype — confirm against the
    # installed transformers version.

    with torch.no_grad():
        output_tokens = model_3n.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=True,
        )

    # Decode only the newly generated continuation: slice off the prompt
    # tokens before decoding the first (only) sequence in the batch.
    generated_text = input_processor.batch_decode(
        output_tokens[:, inputs['input_ids'].shape[1]:], skip_special_tokens=True
    )[0]
    return generated_text.strip()

@spaces.GPU()
def gradio_fn(message, history):
    """Gradio chat callback: extend the history with the new user turn and
    return the model's reply as plain text."""
    conversation = history + [(message, None)]
    return infer_text(conversation)

# Build the Gradio UI: custom CSS pins two decorative character images to the
# bottom corners and fixes the chatbot panel height.
with gr.Blocks(css="""
    .gradio-container {
        max-width: 600px;
        margin: auto;
        padding: 20px;
        font-family: sans-serif;
        position: relative;
        }
    .chatbot {
        height: 500px !important;
        overflow-y: auto;
        }
    .corner {
       position: fixed;
       bottom: 2px;
       z-index: 9999;
       pointer-events: none;
        } 
    #left { left: 2px; }
    #right { right: 2px; }
    .corner img {
       height: 500px;  /* fixed height */
       width: auto;    /* auto to keep aspect ratio */
        }
    
    """) as demo:
    # Centered title/header banner.
    gr.Markdown(
    """
        <div style='text-align: center; padding: 10px;'>
        <h1 style='font-size: 2.2em; margin-bottom: 0.2em;'>🤖 <span style='color: #4F46E5;'>kRISHNA.ai</span></h1>
        <p style='font-size: 1.1em; color: #555;'>5000-Years of Ancient WISDOM with Modern AI ✨</p>
        </div>
    """,
    elem_id="header"
    )
    # Chat widget wired to the GPU-decorated callback; example prompts shown
    # before the first message.
    chat = gr.ChatInterface(
        fn=gradio_fn,
        examples=[
            "Hello!",
            "How can I overcome fear of failure?",
            "How do I forgive someone who hurt me deeply?",
            "What can I do to stop overthinking?"
        ],
        chatbot=gr.Chatbot(elem_classes="chatbot"),
        theme="compact",
    )
    # Decorative corner images (styled by the .corner CSS rules above).
    gr.HTML(f"""
      <div id="left" class="corner">
        <img src="https://huggingface.co/spaces/p2kalita/kRISHNA.ai/resolve/main/assets/Arjun.png" alt="Arjun">
      </div>
      <div id="right" class="corner">
        <img src="https://huggingface.co/spaces/p2kalita/kRISHNA.ai/resolve/main/assets/Krishna.png" alt="Krishna">
      </div>
    """)


# Script entry point: start the Gradio server when run directly.
if __name__ == "__main__":
    demo.launch()