File size: 4,685 Bytes
736e275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110

import os, torch, gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel

def env_flag(name, default="false"):
    """Return True when environment variable `name` holds a truthy value.

    "1", "true", "yes" and "on" count as true (case-insensitive, surrounding
    whitespace ignored); any other value, including "", counts as false.
    `default` is the raw string used when the variable is unset.
    """
    return os.getenv(name, default).strip().lower() in {"1", "true", "yes", "on"}


BASE_MODEL   = os.getenv("BASE_MODEL", "Qwen/Qwen2.5-7B")
ADAPTER_REPO = os.getenv("ADAPTER_REPO", "your-username/tt-qwen25-7b-tt-lora")
# Try 4-bit NF4 loading first unless LOAD_IN_4BIT is set to a falsy value
# (e.g. "false", "0", "no", "off").
LOAD_IN_4BIT = env_flag("LOAD_IN_4BIT", "true")

def _load_base_4bit():
    """Load BASE_MODEL quantized to 4-bit NF4; print a warning and return None on failure."""
    try:
        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,  # float16 compute for Spaces GPUs
        )
        backbone = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL, quantization_config=nf4_config, device_map="auto"
        )
    except Exception as err:
        # Deliberate best-effort: the caller falls back to 8-bit / FP16.
        print("[warn] 4-bit failed:", err)
        return None
    print("Loaded base in 4-bit NF4")
    return backbone


def _load_base_8bit_or_fp16():
    """Load BASE_MODEL in 8-bit; if that raises, load it unquantized in float16."""
    try:
        backbone = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            device_map="auto",
        )
    except Exception as err:
        print("[warn] 8-bit failed:", err)
        # Last resort; errors from this load propagate to the caller.
        backbone = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL, torch_dtype=torch.float16, device_map="auto"
        )
        print("Loaded base in FP16 (may offload to CPU)")
    else:
        print("Loaded base in 8-bit")
    return backbone


def load_model():
    """Load the tokenizer and the LoRA-adapted model, ready for inference.

    The base model is loaded in 4-bit NF4 when LOAD_IN_4BIT is set. If that is
    disabled or fails, it falls back to 8-bit and then to plain float16. The
    adapter from ADAPTER_REPO is attached frozen (is_trainable=False), cast to
    float16 and switched to eval mode.

    Returns:
        (tokenizer, model) tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
    if tokenizer.pad_token_id is None:
        # No dedicated pad token: reuse EOS so generate() has a pad id.
        tokenizer.pad_token = tokenizer.eos_token

    backbone = _load_base_4bit() if LOAD_IN_4BIT else None
    if backbone is None:
        backbone = _load_base_8bit_or_fp16()

    backbone.config.pad_token_id = tokenizer.pad_token_id
    adapted = PeftModel.from_pretrained(
        backbone, ADAPTER_REPO, is_trainable=False, torch_dtype=torch.float16
    ).to(dtype=torch.float16)
    adapted.eval()
    return tokenizer, adapted

# Load the tokenizer and adapter-wrapped model once, at import time.
# format_prompt() and respond() read these module-level globals on every request.
tok, model = load_model()

def format_prompt(user, system, mode):
    """Tokenize a single user turn, plus the system prompt, in the chosen format.

    When mode is "Qwen chat", the messages are rendered with the tokenizer's
    chat template, ending with the assistant generation prompt. Any other mode
    uses the plain-text "<|system|> / <|user|> / <|assistant|>" SFT tag layout.

    Returns:
        dict with "input_ids" and "attention_mask" tensors on model.device,
        ready to be unpacked into model.generate().
    """
    if mode != "Qwen chat":
        encoded = tok(
            f"<|system|> {system}\n<|user|> {user}\n<|assistant|>",
            return_tensors="pt",
        )
        ids = encoded["input_ids"]
        mask = encoded["attention_mask"]
    else:
        conversation = [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ]
        ids = tok.apply_chat_template(
            conversation, add_generation_prompt=True, return_tensors="pt"
        )
        # One unpadded sequence: every position is a real token.
        mask = torch.ones_like(ids)
    target = model.device
    return {"input_ids": ids.to(target), "attention_mask": mask.to(target)}

@torch.inference_mode()
def respond(message, history, system_prompt, mode, temperature, top_p, rep_penalty, max_new_tokens):
    """Generate one assistant reply for gr.ChatInterface.

    Args:
        message: the latest user message.
        history: earlier turns passed by gr.ChatInterface. NOTE: unused, so each
            reply is conditioned only on `system_prompt` and `message`
            (single-turn chat).
        system_prompt, mode: forwarded to format_prompt().
        temperature, top_p, rep_penalty, max_new_tokens: sampling settings
            from the UI sliders.

    Returns:
        The decoded continuation only, with prompt tokens and special tokens stripped.
    """
    inputs = format_prompt(message, system_prompt, mode)

    # Stop on EOS and also on the chat-template turn terminator "<|im_end|>"
    # when the vocabulary has it. For base (non-Instruct) Qwen checkpoints the
    # EOS token is typically "<|endoftext|>" (verify for BASE_MODEL). Without
    # "<|im_end|>" as a stop id, "Qwen chat" mode keeps generating past the end
    # of the assistant turn.
    stop_ids = []
    if tok.eos_token_id is not None:
        stop_ids.append(tok.eos_token_id)
    im_end_id = tok.convert_tokens_to_ids("<|im_end|>")
    # Unknown tokens map to unk_token_id (or None), so exclude those.
    if im_end_id is not None and im_end_id != tok.unk_token_id and im_end_id not in stop_ids:
        stop_ids.append(im_end_id)

    # autocast("cuda") only warns and disables itself without a GPU; make that explicit.
    with torch.autocast("cuda", dtype=torch.float16, enabled=torch.cuda.is_available()):
        out = model.generate(
            **inputs,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=rep_penalty,
            # Slider values may arrive as floats; generate() expects an int budget.
            max_new_tokens=int(max_new_tokens),
            pad_token_id=tok.pad_token_id,
            eos_token_id=stop_ids or None,
            no_repeat_ngram_size=4,
        )
    prompt_len = inputs["input_ids"].shape[1]
    return tok.decode(out[0][prompt_len:], skip_special_tokens=True)

# Gradio UI. The system-prompt / prompt-format controls and the sampling
# sliders are passed to respond() via ChatInterface.additional_inputs, in the
# same order as respond()'s parameters after (message, history).
with gr.Blocks() as demo:
    gr.Markdown("## Татарча чат-демо (Qwen2.5-7B + LoRA)")
    gr.Markdown("Бета-версия. Модель обучена отвечать **по-татарски**. Если переключаться на русский/английский — это ошибка; сообщите нам примеры.")
    with gr.Row():
        system_prompt = gr.Textbox(
            value="Син бары тик татарча гына җавап бир. Җавапларың кыска һәм нейтраль булсын.",
            label="System prompt"
        )
        mode = gr.Radio(choices=["SFT tags", "Qwen chat"], value="SFT tags", label="Формат промпта")
    with gr.Row():
        temperature      = gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="temperature")
        top_p            = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
        rep_penalty      = gr.Slider(1.0, 1.4, value=1.15, step=0.05, label="repetition_penalty")
        max_new_tokens   = gr.Slider(16, 512, value=200, step=8, label="max_new_tokens")

    # NOTE(review): undo_btn / retry_btn / clear_btn are Gradio 3.x/4.x
    # ChatInterface arguments; they appear to have been removed in Gradio 5.
    # Confirm against the pinned gradio version.
    gr.ChatInterface(
        fn=respond,
        additional_inputs=[system_prompt, mode, temperature, top_p, rep_penalty, max_new_tokens],
        title=None, undo_btn=None, retry_btn=None, clear_btn="Clear"
    )

# Serve one generation at a time, since there is a single model instance.
# Gradio 4 replaced queue(concurrency_count=...) with
# queue(default_concurrency_limit=...) and rejects the old argument. Gradio 3.x
# does not accept the new one and raises TypeError, so fall back to the old form.
try:
    demo.queue(default_concurrency_limit=1)
except TypeError:
    demo.queue(concurrency_count=1)
demo.launch()