File size: 3,357 Bytes
cc6f54d
 
73f52cd
cc6f54d
73f52cd
cc6f54d
73f52cd
cc6f54d
 
 
bec8f6d
cc6f54d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2d905e
cc6f54d
73f52cd
cc6f54d
 
 
fd78eab
 
cc6f54d
 
 
 
 
fd78eab
cc6f54d
 
fd78eab
cc6f54d
 
 
fd78eab
cc6f54d
 
 
 
fd78eab
cc6f54d
fd78eab
73f52cd
 
 
 
cc6f54d
73f52cd
cc6f54d
73f52cd
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import time
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face checkpoint that this app loads and serves.
MODEL_NAME = "google/gemma-3-270m-it"

# Per-user KV-cache store, laid out as {user_id: past_key_values}.
sessions = dict()

# CPU tuning.
# NOTE(review): this re-applies the thread count PyTorch already reports, so
# it does not by itself raise the number of threads in use.
torch.set_num_threads(torch.get_num_threads())
# Intended to speed up float32 matmuls (may trade some precision for speed
# where the backend supports it).
torch.set_float32_matmul_precision("high")

# Load the tokenizer and the model once at import time so every request
# shares them. The model is kept in float32 for CPU execution and switched
# to evaluation mode right away (`eval()` returns the module itself).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float32, device_map=None
).eval()

def build_prompt(message, history, system_message, max_ctx_tokens=1024):
    """Render the conversation with the chat template, trimmed to a token budget.

    The prompt always contains the system message and the current user
    ``message``. Older history is dropped, oldest exchange first, until the
    rendered prompt fits into ``max_ctx_tokens``.

    Args:
        message: The new user message to answer.
        history: Iterable of ``(user_text, assistant_text)`` pairs from earlier
            turns; empty/``None`` entries inside a pair are skipped.
        system_message: Text placed in the leading system turn.
        max_ctx_tokens: Maximum prompt length in tokens, counted without
            tokenizer-added special tokens (the template output is counted
            as-is).

    Returns:
        The templated prompt string, ending with the generation prompt. If the
        system message plus the current message alone exceed the budget, the
        prompt is returned anyway (over budget) rather than dropping the
        message the model is supposed to answer.
    """
    past_turns = []
    for user_text, assistant_text in history:
        if user_text:
            past_turns.append({"role": "user", "content": user_text})
        if assistant_text:
            past_turns.append({"role": "assistant", "content": assistant_text})

    system_turn = {"role": "system", "content": system_message}
    current_turn = {"role": "user", "content": message}

    while True:
        msgs = [system_turn, *past_turns, current_turn]
        text = tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
        n_tokens = len(tokenizer(text, add_special_tokens=False).input_ids)

        # Stop when the prompt fits, or when only the mandatory turns are
        # left: trimming further would remove the current message, and
        # looping with nothing to delete would never terminate.
        if n_tokens <= max_ctx_tokens or not past_turns:
            return text

        # Drop the oldest exchange: the first remaining turn plus anything up
        # to the next user turn, so the kept history starts on a user turn.
        del past_turns[0]
        while past_turns and past_turns[0]["role"] != "user":
            del past_turns[0]

# Separator that introduces the throughput footer appended to every reply.
_SPEED_MARKER = "\n\n⚡ **Hız:** "


def _strip_speed_footer(reply):
    """Return ``reply`` without the throughput footer added by ``respond``."""
    if not isinstance(reply, str):
        return reply
    cut = reply.rfind(_SPEED_MARKER)
    return reply if cut == -1 else reply[:cut]


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Generate the assistant reply for one chat turn.

    The full prompt (system message, trimmed history, new message) is rebuilt
    and encoded on every call. A KV cache is deliberately NOT carried over
    between calls, and the module-level ``sessions`` dict is not used here:
    a cache from an earlier call only matches if the exact same token prefix
    is fed again. That cannot be guaranteed, because the history is
    re-rendered through the chat template, may be trimmed, and belongs to
    whichever chat happens to call this function.

    Args:
        message: The new user message.
        history: Earlier turns as ``(user_text, assistant_text)`` pairs.
        system_message: System prompt text.
        max_tokens: Upper bound on newly generated tokens.
        temperature: Sampling temperature; values ``<= 0`` select greedy
            decoding.
        top_p: Nucleus-sampling threshold, used only when sampling.

    Returns:
        The decoded reply followed by a footer reporting tokens per second.
    """
    # Earlier replies come back through ``history`` with the speed footer
    # attached; remove it so the model never sees it as conversation text.
    clean_history = [
        (user_text, _strip_speed_footer(assistant_text))
        for user_text, assistant_text in (history or [])
    ]
    text = build_prompt(message, clean_history, system_message)

    # The chat template already emits the special tokens it needs, so the
    # tokenizer must not add another set (this would duplicate BOS).
    inputs = tokenizer(
        [text], return_tensors="pt", add_special_tokens=False
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    gen_kwargs = {"max_new_tokens": int(max_tokens), "use_cache": True}
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # Greedy decoding: sampling-only arguments are left out entirely.
        gen_kwargs["do_sample"] = False

    # perf_counter is monotonic, so the measured span cannot go negative.
    started = time.perf_counter()
    with torch.inference_mode():
        output_ids = model.generate(**inputs, **gen_kwargs)
    elapsed = time.perf_counter() - started

    # The returned sequence starts with the prompt; keep only what was added.
    new_tokens = output_ids[0, prompt_len:]
    content = tokenizer.decode(new_tokens, skip_special_tokens=True).strip("\n")

    # Throughput in generated tokens per second.
    tps = len(new_tokens) / elapsed if elapsed > 0 else 0

    return f"{content}{_SPEED_MARKER}{tps:.2f} token/sn"

# Extra controls shown with the chat box. They are listed in the same order
# as the parameters `respond` takes after (message, history):
# system_message, max_tokens, temperature, top_p.
_ADDITIONAL_INPUTS = [
    gr.Textbox(label="System message", value="You are a friendly Chatbot."),
    gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=256),
    gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
    gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.95),
]

# Chat UI that routes every user message through `respond`.
demo = gr.ChatInterface(respond, additional_inputs=_ADDITIONAL_INPUTS)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()