import os
import time
import threading
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

MODEL_NAME = "daniel-dona/gemma-3-270m-it"

# CPU optimizations
torch.set_num_threads(os.cpu_count())           # use all available cores
torch.set_float32_matmul_precision("high")      # allow faster matmul kernels
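# A further knob worth knowing (an assumption, not part of the original):
# torch.set_num_interop_threads() tunes inter-op parallelism on many-core CPUs;
# it must be called before the first parallel operation runs.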

# Load the model and tokenizer once at startup
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,  # float32 on CPU
    device_map=None
)
model.eval()
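
# Optional speedup (an assumption, not part of the original): on PyTorch 2.x,
# torch.compile can lower per-token latency on CPU at the cost of a slow first
# call. Uncomment to try it:
# model = torch.compile(model)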

# NOTE: ChatInterface resends the full chat history on every turn, so the
# prompt is rebuilt from scratch each time; a sketch of real per-user KV-cache
# reuse is in the comments after respond_stream below.

def build_prompt(message, history, system_message, max_ctx_tokens=1024):
    msgs = [{"role": "system", "content": system_message}]
    for u, a in history:
        if u: msgs.append({"role": "user", "content": u})
        if a: msgs.append({"role": "assistant", "content": a})
    msgs.append({"role": "user", "content": message})

    # Trim to the token budget
    while True:
        text = tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        if len(tokenizer(text, add_special_tokens=False).input_ids) <= max_ctx_tokens:
            return text
        if len(msgs) <= 2:  # only the system and current message remain; stop trimming
            return text
        # Drop the oldest user/assistant pair
        for i in range(1, len(msgs)):
            if msgs[i]["role"] != "system":
                del msgs[i:i+2]
                break
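
# Example (hypothetical values): with the default 1024-token budget,
#   build_prompt("How are you?", [("Hi", "Hello!")], "You are a friendly Chatbot.")
# returns the chat-templated string, dropping the oldest turns only if the
# total would otherwise exceed the budget.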

def respond_stream(message, history, system_message, max_tokens, temperature, top_p):
    # ChatInterface passes the full history, so rebuild the prompt every turn
    text = build_prompt(message, history, system_message)
    inputs = tokenizer([text], return_tensors="pt").to(model.device)

    do_sample = temperature > 0
    gen_kwargs = dict(
        max_new_tokens=max_tokens,
        do_sample=do_sample,
        top_p=top_p if do_sample else None,
        temperature=temperature if do_sample else None,
        use_cache=True,
    )

    # skip_prompt=True keeps the input prompt out of the streamed output
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {**inputs, **{k: v for k, v in gen_kwargs.items() if v is not None},
                       "streamer": streamer}

    def _generate():
        # inference_mode must wrap the thread that runs the model, not the
        # consumer loop below
        with torch.inference_mode():
            model.generate(**generate_kwargs)

    thread = threading.Thread(target=_generate)

    partial_text = ""
    start_time = time.time()

    thread.start()
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    thread.join()

    elapsed = time.time() - start_time
    # Re-tokenize the generated text to count output tokens (the streamer
    # yields decoded text chunks, not token counts)
    token_count = len(tokenizer(partial_text, add_special_tokens=False).input_ids)
    tps = token_count / elapsed if elapsed > 0 else 0
    yield partial_text + f"\n\n⚡ **Speed:** {tps:.2f} tokens/s"
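
# Sketch of real per-user KV-cache reuse (an assumption, untested here):
# generate() must be asked to return the cache, and the next turn must feed
# only the newly chat-templated user turn alongside it, e.g.:
#
#   sessions = {}  # {user_id: past_key_values}
#   out = model.generate(**inputs, max_new_tokens=max_tokens,
#                        use_cache=True, return_dict_in_generate=True)
#   sessions[user_id] = out.past_key_values
#   # next turn: pass past_key_values=sessions[user_id] with the new tokens
#
# It is omitted above because ChatInterface resends the full history anyway.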

demo = gr.ChatInterface(
    respond_stream,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=256, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
    ],
)
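
# Note (an assumption about Gradio versions): older Gradio releases required
# demo.queue() for streaming generator outputs; recent versions enable the
# queue by default. If streaming stalls, try demo.queue().launch() instead.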

if __name__ == "__main__":
    demo.launch()