File size: 4,007 Bytes
98fad21
 
 
afacd0f
98fad21
afacd0f
98fad21
 
a79db8e
 
afacd0f
98fad21
 
 
 
 
afacd0f
98fad21
 
 
 
 
 
 
 
 
 
 
 
 
afacd0f
 
 
 
 
 
 
 
 
 
 
 
98fad21
 
afacd0f
98fad21
afacd0f
 
98fad21
 
 
afacd0f
 
98fad21
afacd0f
 
 
 
 
 
98fad21
 
afacd0f
44436a9
afacd0f
 
 
 
98fad21
afacd0f
 
 
 
 
 
 
 
 
 
 
 
98fad21
 
a761dfe
afacd0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98fad21
 
 
afacd0f
 
98fad21
 
afacd0f
98fad21
 
afacd0f
98fad21
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# server.py
import torch
import threading
import time
import numpy as np
import re
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import gradio as gr

# === Model ===
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# NOTE(review): load_in_4bit / bnb_4bit_* require bitsandbytes with a CUDA
# device, but device_map="cpu" pins the model to CPU — this combination is
# contradictory and will typically raise or be ignored at load time.
# Confirm the intended target device (and consider BitsAndBytesConfig).
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu",
    torch_dtype=torch.float16,
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Best-effort compilation: fall back to the eager model if torch.compile is
# unavailable. `except Exception` (not a bare `except:`) so that
# KeyboardInterrupt / SystemExit still propagate.
try:
    model = torch.compile(model, mode="reduce-overhead", fullgraph=True)
    print("✅ torch.compile активирован")
except Exception:
    pass

# === Tools ===
# JSON-schema description of the single simulated tool (get_weather).
# NOTE(review): this list is never passed to the model/tokenizer anywhere in
# this file (e.g. apply_chat_template(tools=...)) and the entry lacks a
# "description" field — confirm whether it is meant to be wired up.
tools = [{
    "name": "get_weather",
    "parameters": {"type": "object", "properties": {"city": {"type": "string"}}, "required": ["city"]}
}]

def execute_tool_call(call):
    """Simulate executing a parsed tool call and return a human-readable result.

    *call* is a dict as parsed from the model's tool-call JSON; only
    call["arguments"]["city"] is consulted, with a fallback placeholder
    when it is absent.
    """
    arguments = call.get("arguments", {})
    city = arguments.get("city", "неизвестен")
    return f"🌤️ Погода в {city}: 22°C, солнечно. (симуляция)"

# === Tool-call parser ===
# BUG FIX: the previous implementation built np.array(list(buffer), dtype='U1')
# (an array of SINGLE characters) and compared it to the 11-character string
# '<tool_call>' — that comparison can never be true, so no tool call was ever
# detected and the function always returned (buffer, False). The +2/-2 slice
# offsets also assumed 2-character delimiters. Replaced with a compiled regex
# over complete <tool_call>...</tool_call> blocks.
_TOOL_CALL_BLOCK = re.compile(r'<tool_call>(.*?)</tool_call>', re.DOTALL)

def find_and_replace_tool_calls_numpy(buffer):
    """Replace complete ``<tool_call>...</tool_call>`` blocks with tool results.

    Each complete block whose content contains a parseable JSON object is
    executed via execute_tool_call() and the whole block is substituted with
    the formatted result. Incomplete blocks (no closing tag yet, e.g. while
    still streaming) and blocks with malformed JSON are left untouched.

    Returns:
        tuple[str, bool]: (possibly rewritten buffer, whether anything was
        replaced). The name is kept for backward compatibility with callers.
    """
    new_buffer = buffer
    replaced = False
    for match in _TOOL_CALL_BLOCK.finditer(buffer):
        block = match.group(0)
        content = match.group(1).strip()
        # Extract the outermost JSON object from the block's content.
        json_match = re.search(r'\{.*\}', content, re.DOTALL)
        if not json_match:
            continue
        try:
            data = json.loads(json_match.group())
        except json.JSONDecodeError:
            # Malformed / partially streamed JSON — leave the block as-is.
            continue
        result = execute_tool_call(data)
        new_buffer = new_buffer.replace(block, f"\n\n✅ {result}\n\n")
        replaced = True
    return new_buffer, replaced

# === Generation with a "GPU effect" (chunked streaming) ===
def generate_stream(prompt, max_new_tokens=128, temperature=0.7, top_p=0.9):
    """Stream a chat completion for *prompt*, yielding accumulated text.

    Runs ``model.generate`` on a background thread and consumes tokens from a
    ``TextIteratorStreamer``, yielding the full text generated so far (not a
    delta) in chunks — on punctuation, every ~30 buffered characters, or after
    0.7 s — so a UI can update progressively. When a ``<tool_call>`` marker
    appears, it attempts to replace completed tool-call blocks via
    ``find_and_replace_tool_calls_numpy``.

    Args:
        prompt: User message; wrapped as a single-turn chat via the
            tokenizer's chat template.
        max_new_tokens: Generation cap, forwarded to ``model.generate``.
        temperature: Sampling temperature (sampling is always enabled).
        top_p: Nucleus-sampling threshold.

    Yields:
        str: the entire generated text accumulated so far.
    """
    messages = [{"role": "user", "content": prompt}]
    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(model.device)

    # skip_prompt=True: the streamer emits only newly generated tokens.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = threading.Thread(target=model.generate, kwargs={
        "input_ids": inputs,
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": True,
        "pad_token_id": tokenizer.pad_token_id,  # NOTE(review): may be None for some tokenizers — confirm
        "eos_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
        "use_cache": True
    })
    # Generation runs concurrently; iterating the streamer below blocks until
    # tokens arrive and ends once generate() finishes (thread is not joined).
    thread.start()

    buffer = ""      # text accumulated since the last yield (flush window)
    full_text = ""   # everything generated so far (what gets yielded)
    last_yield = time.time()

    for token in streamer:
        buffer += token
        full_text += token

        # Tool-call handling: once a marker is seen, try to rewrite completed
        # tool-call blocks in the full text; on success flush immediately.
        if "<tool_call>" in buffer:
            processed, changed = find_and_replace_tool_calls_numpy(full_text)
            if changed:
                full_text = processed
                buffer = ""
                yield full_text
                continue

        # Flush heuristics: chunk size, sentence punctuation, or timeout.
        now = time.time()
        if (len(buffer) >= 30 or 
            any(p in buffer for p in ".!?;\n") or 
            now - last_yield > 0.7):
            yield full_text
            buffer = ""
            last_yield = now

    # Final yield so the caller sees any trailing text left in the buffer.
    if full_text:
        yield full_text

# === Gradio ===
# Minimal UI: one text input, three sampling sliders, streamed text output.
with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Ввод", placeholder="Спроси что-нибудь...")
    max_tokens_slider = gr.Slider(64, 256, 128, step=32, label="Max Tokens")
    temperature_slider = gr.Slider(0.1, 1.5, 0.7, step=0.1, label="Temperature")
    top_p_slider = gr.Slider(0.5, 1.0, 0.9, step=0.05, label="Top-p")
    run_button = gr.Button("🚀 GPU-режим")
    answer_box = gr.Textbox(label="Ответ")

    # generate_stream is a generator, so Gradio streams each yielded chunk
    # into the output textbox.
    run_button.click(
        generate_stream,
        [prompt_box, max_tokens_slider, temperature_slider, top_p_slider],
        answer_box,
    )

if __name__ == "__main__":
    demo.launch()