File size: 4,704 Bytes
8e02470
0c8c300
8e02470
 
 
 
ff4ac49
8e02470
ff4ac49
ddd43ff
ff4ac49
8e02470
ddd43ff
8e02470
c44a1f7
0c8c300
c44a1f7
7130460
9678c10
ff4ac49
0c8c300
 
c44a1f7
0c8c300
ff4ac49
b2f5cfc
0c8c300
7130460
 
8e02470
ff4ac49
8e02470
ff4ac49
ddd43ff
d051ac2
7130460
0c8c300
ddd43ff
ff4ac49
 
0c8c300
d051ac2
8e02470
ff4ac49
4e2cd1c
ddd43ff
0c8c300
ddd43ff
4e2cd1c
8e02470
0c8c300
ddd43ff
8e02470
9678c10
8e02470
 
4e2cd1c
a42d0d2
9678c10
0c8c300
8e02470
 
a42d0d2
 
 
 
 
 
 
9678c10
 
 
 
 
 
 
 
ff4ac49
 
 
 
 
 
0c8c300
ff4ac49
 
 
 
4e2cd1c
 
ff4ac49
 
0c8c300
ff4ac49
 
 
 
 
 
7130460
 
0c8c300
 
 
 
 
4e2cd1c
 
11833aa
a42d0d2
 
 
11833aa
9678c10
11833aa
a42d0d2
 
 
 
 
 
 
 
11833aa
 
 
2e74cf3
0c8c300
ddd43ff
b2f5cfc
4e2cd1c
ddd43ff
 
 
2e74cf3
 
ddd43ff
 
a42d0d2
 
ddd43ff
11833aa
ddd43ff
c44a1f7
2e74cf3
 
11833aa
2e74cf3
11833aa
7130460
11833aa
7130460
 
2e74cf3
 
 
0c8c300
7130460
4e2cd1c
0c8c300
8e02470
 
 
d051ac2
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
import time
import psutil
import os
import torch

# Model identifier passed to the transformers from_pretrained() calls in
# load_model().
MODEL_ID    = "Qwen/Qwen2.5-0.5B-Instruct"
# Filled in by load_model() on a background thread; chat() refuses to
# generate while either of them is still None.
model       = None
tokenizer   = None
# Human-readable loading state; rendered into the UI by get_stats_md().
load_status = "🔄 Initializing..."
# Timestamp taken at import, used by load_model() to report the load time.
load_start  = time.time()


def get_ram_mb() -> float:
    """Return the resident set size (RSS) of the current process in MiB."""
    this_process = psutil.Process(os.getpid())
    rss_bytes = this_process.memory_info().rss
    return rss_bytes / 1024**2


def get_stats_md(tps=None, tokens=None, elapsed=None) -> str:
    """Render the status panel as a Markdown string.

    The panel always shows the current ``load_status`` and a ten-cell RAM
    gauge (one cell per 150 MB of process memory, capped at ten cells).
    When ``tps`` is given, a third line with speed, token count and elapsed
    seconds is appended; ``tokens`` and ``elapsed`` are expected alongside it.
    """
    ram_mb = get_ram_mb()
    cells = 10
    lit = min(int(ram_mb / 150), cells)
    gauge = "█" * lit + "░" * (cells - lit)

    # Lines are joined with two trailing spaces + newline, which is a
    # Markdown hard line break.
    rows = [
        f"**Status:** {load_status}",
        f"**RAM:** `[{gauge}]` **{ram_mb:.0f} MB**",
    ]
    if tps is not None:
        rows.append(
            f"**Speed:** {tps:.1f} t/s · **Tokens:** {tokens} · **Elapsed:** {elapsed:.1f}s"
        )
    return "  \n".join(rows)


def load_model():
    """Load the tokenizer and model for ``MODEL_ID`` into the module globals.

    Each stage is announced by updating ``load_status`` and printing it.
    Any exception raised along the way is caught and reported through
    ``load_status`` (prefixed with a cross mark) instead of propagating.
    """
    global model, tokenizer

    def _announce(text):
        # Publish the new status for the UI and echo it to stdout.
        global load_status
        load_status = text
        print(text)

    try:
        _announce("🔄 Loading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

        _announce("🔄 Loading model weights...")
        weight_options = {
            "torch_dtype": torch.float32,
            "low_cpu_mem_usage": True,
        }
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **weight_options)
        model.eval()

        seconds_taken = time.time() - load_start
        ram_now = get_ram_mb()
        _announce(f"✅ Ready — {ram_now:.0f} MB · {seconds_taken:.0f}s")

    except Exception as exc:
        _announce(f"❌ {exc}")


# Start loading on a daemon thread at import time so the rest of the module
# (including the UI definition below) runs without waiting for the weights.
Thread(target=load_model, daemon=True).start()


def chat(message: str, prior_messages: list, system_prompt: str):
    """Stream the model's reply to ``message`` as a generator.

    Args:
        message: The new user message to answer.
        prior_messages: Earlier conversation turns as role/content dicts;
            they are placed between the system prompt and ``message``.
        system_prompt: Optional system instruction; skipped when blank.

    Yields:
        ``(text, stats_markdown)`` tuples, where ``text`` is the reply
        accumulated so far and ``stats_markdown`` comes from
        ``get_stats_md``. While the model is still loading, a single
        placeholder tuple is yielded. If generation fails, a final tuple
        carrying the error text is yielded instead of blocking forever.
    """
    if model is None or tokenizer is None:
        yield "⏳ Still loading...", get_stats_md()
        return

    # Conversation order: optional system prompt, prior turns, new user turn.
    messages = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt.strip()})
    messages.extend(prior_messages)
    messages.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = tokenizer(prompt, return_tensors="pt")

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True
    )

    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id
    )
    generate = model.generate
    failures = []

    def _run_generation():
        # Without this guard, an exception inside generate() would leave the
        # streamer without its stop signal and the consumer loop below would
        # wait on it indefinitely.
        try:
            generate(**generation_kwargs)
        except Exception as exc:
            import traceback

            traceback.print_exc()  # keep the diagnostic the thread would have printed
            failures.append(exc)
            streamer.end()  # unblock the consumer loop

    Thread(target=_run_generation).start()

    t0     = time.time()
    output = ""
    count  = 0
    # NOTE(review): `count` is the number of text chunks emitted by the
    # streamer; a chunk may not map one-to-one to a model token, so the
    # reported tokens / t/s are presumably approximate -- confirm.
    for chunk in streamer:
        output += chunk
        count  += 1
        elapsed = time.time() - t0
        yield output, get_stats_md(
            tps=count / elapsed if elapsed > 0 else 0,
            tokens=count,
            elapsed=elapsed
        )

    if failures:
        # Keep whatever text was produced and append the error beneath it.
        separator = "\n\n" if output else ""
        yield f"{output}{separator}❌ {failures[0]}", get_stats_md()


def user_turn(message, history):
    """Record the user's message in the chat history.

    Appends a ``{"role": "user", "content": message}`` entry to ``history``
    in place and returns ``("", history)``: the empty string clears the
    input textbox, and the same list object is handed back to the chatbot.
    """
    user_entry = dict(role="user", content=message)
    history.append(user_entry)
    cleared_input = ""
    return cleared_input, history


def bot_turn(history, system):
    """Stream the assistant's answer to the last user entry in ``history``.

    The final entry of ``history`` is taken as the pending user message and
    everything before it as prior context. An empty assistant entry is
    appended and then filled in place with each partial reply from ``chat``.

    Yields:
        ``(history, stats)`` after every streamed update.
    """
    pending_text = history[-1]["content"]
    # Copy of everything before the pending user message; taken before the
    # assistant placeholder is appended so the placeholder is not fed back in.
    earlier_turns = history[:len(history) - 1]

    # Placeholder the UI renders while text streams in.
    reply_slot = {"role": "assistant", "content": ""}
    history.append(reply_slot)

    for partial_reply, stats_markdown in chat(pending_text, earlier_turns, system):
        reply_slot["content"] = partial_reply
        yield history, stats_markdown


# Build the Gradio UI: a stats panel, a collapsible system-prompt box, the
# chat transcript, an input row and a clear button. Components appear in the
# order they are created here.
with gr.Blocks(title="Qwen 0.5B") as demo:
    gr.Markdown("## 🧠 Qwen2.5-0.5B · CPU")

    # Stats panel; seeded with the current status and rewritten by bot_turn
    # on every streamed update.
    stats_md = gr.Markdown(value=get_stats_md())

    with gr.Accordion("⚙️ System Prompt", open=False):
        system_box = gr.Textbox(
            value="You are a helpful assistant.",
            lines=3,
            show_label=False
        )

    # type="messages": history is exchanged as role/content dicts, which is
    # the shape user_turn and bot_turn append.
    chatbot = gr.Chatbot(value=[], type="messages", show_label=False, height=400)

    with gr.Row():
        msg = gr.Textbox(
            placeholder="Type a message…",
            show_label=False,
            scale=9,
            lines=1
        )
        send_btn = gr.Button("➤", variant="primary", scale=1)

    clear = gr.Button("🗑️ Clear")

    # Pressing Enter in the textbox and clicking the send button run the same
    # two-step chain: user_turn records the message and clears the input,
    # then bot_turn streams the reply into the chatbot and the stats panel.
    for trigger in [msg.submit, send_btn.click]:
        trigger(
            user_turn, [msg, chatbot], [msg, chatbot], queue=False
        ).then(
            bot_turn, [chatbot, system_box], [chatbot, stats_md]
        )

    # Reset the transcript to an empty list and blank the input box.
    clear.click(lambda: ([], ""), outputs=[chatbot, msg], queue=False)


# Launch the web server only when executed as a script; listens on all
# interfaces (0.0.0.0) at port 7860.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)