File size: 2,140 Bytes
4ddd7a9
 
 
 
 
 
 
 
 
0be43d1
 
4ddd7a9
 
 
 
 
 
0be43d1
 
 
4ddd7a9
0be43d1
4ddd7a9
b3e26e8
 
 
 
 
 
 
 
 
 
 
 
4ddd7a9
0be43d1
 
b3e26e8
0be43d1
 
4ddd7a9
 
 
0be43d1
 
 
 
4ddd7a9
 
0be43d1
 
 
 
 
 
 
 
 
 
 
 
 
 
4ddd7a9
 
 
 
 
 
 
 
 
 
0be43d1
 
 
 
 
 
 
 
4ddd7a9
0be43d1
 
 
4ddd7a9
0be43d1
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import subprocess, sys

# Runtime dependency bootstrap: install the minimum package versions with the
# interpreter that is running this script, before the third-party imports that
# follow. check_call raises CalledProcessError if pip exits non-zero, so a
# failed install aborts start-up instead of failing later at import time.
_PIP_REQUIREMENTS = [
    "transformers>=4.45.0",
    "accelerate>=0.26.0",
    "sentencepiece>=0.1.99",
]
_pip_command = [sys.executable, "-m", "pip", "install", "--quiet", *_PIP_REQUIREMENTS]
subprocess.check_call(_pip_command)

import gradio as gr
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)
from threading import Thread

# Hub identifier used for both the tokenizer and the model weights.
MODEL_ID = "google/gemma-4-31B-it-assistant"

# --- Tokenizer --------------------------------------------------------------
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# The chat template is assigned by hand; per the original author's note the
# tokenizer does not ship with one. Each turn is rendered as a role line
# ("user" / "model") followed by the message content, and an open "model" line
# is appended when a generation prompt is requested.
# NOTE(review): the template emits bare role lines with no turn-delimiter
# tokens, and it unconditionally replaces any template the tokenizer may
# already carry -- confirm this matches the prompt format the model expects.
_CHAT_TEMPLATE_PARTS = (
    "{% for message in messages %}",
    "{% if message['role'] == 'user' %}",
    "user\n{{ message['content'] }}\n",
    "{% elif message['role'] == 'assistant' %}",
    "model\n{{ message['content'] }}\n",
    "{% endif %}",
    "{% endfor %}",
    "{% if add_generation_prompt %}model\n{% endif %}",
)
tokenizer.chat_template = "".join(_CHAT_TEMPLATE_PARTS)

# --- Model ------------------------------------------------------------------
print("Loading model...")
# NOTE(review): `dtype=` is passed as written in the original; confirm the
# installed transformers release accepts this keyword.
_model_load_options = {
    "dtype": torch.bfloat16,
    "device_map": "auto",
}
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_model_load_options)
model.eval()  # evaluation mode: the model is only used for generation here
print("Model ready.")


def _content_as_text(content):
    """Return *content* as plain text, or ``None`` if it holds no usable text.

    Strings are returned unchanged. A list/tuple of parts is reduced to the
    ``"text"`` values of its dict parts, joined with newlines (assumes parts
    look like ``{"type": "text", "text": ...}`` -- TODO confirm against the
    installed Gradio version). Anything else (``None``, file tuples, other
    objects) yields ``None`` so the caller can skip it.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, (list, tuple)):
        texts = [
            part["text"]
            for part in content
            if isinstance(part, dict) and isinstance(part.get("text"), str)
        ]
        return "\n".join(texts) if texts else None
    return None


def _history_to_messages(history):
    """Convert chat history into a list of ``{"role", "content"}`` dicts.

    Accepts both history shapes: ``(user_msg, bot_msg)`` pairs and
    role/content dicts. Only ``user`` and ``assistant`` turns with textual
    content are kept, because those are the only roles the chat template
    renders.
    """
    messages = []
    for entry in history or []:
        if isinstance(entry, dict):
            turns = [(entry.get("role"), entry.get("content"))]
        else:
            user_msg, bot_msg = entry
            turns = [("user", user_msg), ("assistant", bot_msg)]
        for role, content in turns:
            text = _content_as_text(content)
            if role in ("user", "assistant") and text is not None:
                messages.append({"role": role, "content": text})
    return messages


def chat(message, history):
    """Stream the model's reply to *message* given the prior *history*.

    Args:
        message: The new user message.
        history: Previous turns, either as ``(user_msg, bot_msg)`` pairs or as
            role/content dicts; may be empty or ``None``.

    Yields:
        The reply accumulated so far, once per chunk produced by the streamer.

    Raises:
        Exception: Whatever ``model.generate`` raised in the worker thread is
            re-raised here once the stream has been drained.
    """
    messages = _history_to_messages(history)
    messages.append({"role": "user", "content": message})

    # Ask for a mapping explicitly so the result shape does not depend on the
    # library's default, and so the attention mask is available to generate().
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generation_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
    )

    failures = []

    def _run_generation():
        # Runs in the worker thread. If generate() fails, the streamer would
        # never receive its end signal and the consumer loop below would block
        # forever, so close the stream and hand the error to the consumer.
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # thread boundary: re-raised by the consumer
            failures.append(exc)
            streamer.end()

    worker = Thread(target=_run_generation)
    worker.start()

    partial = ""
    for token in streamer:
        partial += token
        yield partial

    worker.join()
    if failures:
        raise failures[0]


# Build the chat UI around the streaming `chat` generator and start serving it.
_chat_ui_options = {
    "fn": chat,
    "title": "Gemma 4 Assistant",
    "description": "google/gemma-4-31B-it-assistant — streaming enabled",
}
demo = gr.ChatInterface(**_chat_ui_options)

# Launch immediately when the script runs.
demo.launch()