Hugging Face Spaces app (status at capture: Runtime error).
| import os, torch, gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| from peft import PeftModel | |
# Deployment configuration — overridable via environment variables
# (e.g. Hugging Face Space settings/secrets).
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen2.5-7B")
ADAPTER_REPO = os.environ.get("ADAPTER_REPO", "your-username/tt-qwen25-7b-tt-lora")
LOAD_IN_4BIT = os.environ.get("LOAD_IN_4BIT", "true").lower() == "true"
def load_model():
    """Load the tokenizer and base model, then attach the LoRA adapter.

    Loading strategy (best effort, in order): 4-bit NF4 -> 8-bit -> plain FP16.
    Each quantized attempt is wrapped in try/except so a missing/broken
    bitsandbytes install degrades gracefully instead of crashing the Space.

    Returns:
        (tokenizer, model): the tokenizer and an eval-mode PeftModel.
    """
    tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
    if tok.pad_token_id is None:
        # Qwen tokenizers may ship without a pad token; reuse EOS for padding.
        tok.pad_token = tok.eos_token

    base = None
    quantized = False  # tracks whether bitsandbytes quantization succeeded
    if LOAD_IN_4BIT:
        try:
            bnb_cfg = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,  # float16 for Spaces GPUs
            )
            base = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL, quantization_config=bnb_cfg, device_map="auto"
            )
            quantized = True
            print("Loaded base in 4-bit NF4")
        except Exception as e:
            print("[warn] 4-bit failed:", e)
    if base is None:
        try:
            bnb8 = BitsAndBytesConfig(load_in_8bit=True)
            base = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL, quantization_config=bnb8, device_map="auto"
            )
            quantized = True
            print("Loaded base in 8-bit")
        except Exception as e:
            print("[warn] 8-bit failed:", e)
            base = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL, torch_dtype=torch.float16, device_map="auto"
            )
            print("Loaded base in FP16 (may offload to CPU)")

    base.config.pad_token_id = tok.pad_token_id
    model = PeftModel.from_pretrained(
        base, ADAPTER_REPO, is_trainable=False, torch_dtype=torch.float16
    )
    if not quantized:
        # BUGFIX: `.to(dtype=...)` raises a ValueError on bitsandbytes-quantized
        # models in current transformers; only cast the plain FP16 fallback.
        model = model.to(dtype=torch.float16)
    model.eval()
    return tok, model
| tok, model = load_model() | |
def format_prompt(user, system, mode):
    """Build tokenized model inputs for one (system, user) exchange.

    mode == "Qwen chat" renders via the tokenizer's chat template; any other
    mode falls back to the plain SFT tag format used during fine-tuning.
    Returns a dict with `input_ids` and `attention_mask` on the model's device.
    """
    if mode == "Qwen chat":
        messages = [
            {"role": "system", "content": system},
            {"role": "user", "content": user},
        ]
        ids = tok.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        )
        # apply_chat_template returns only ids; there is no padding here,
        # so an all-ones mask is correct.
        mask = torch.ones_like(ids)
    else:
        text = f"<|system|> {system}\n<|user|> {user}\n<|assistant|>"
        encoded = tok(text, return_tensors="pt")
        ids = encoded["input_ids"]
        mask = encoded["attention_mask"]
    return {
        "input_ids": ids.to(model.device),
        "attention_mask": mask.to(model.device),
    }
def respond(message, history, system_prompt, mode, temperature, top_p, rep_penalty, max_new_tokens):
    """Gradio ChatInterface callback: generate a single reply to `message`.

    Note: `history` is accepted (ChatInterface supplies it) but the prompt is
    rebuilt from `message` + `system_prompt` only, so prior turns do not
    condition the model.
    """
    inputs = format_prompt(message, system_prompt, mode)
    # BUGFIX: autocast only when CUDA actually exists — load_model()'s FP16
    # fallback may place the model on CPU, where entering a CUDA autocast
    # context is invalid.
    use_cuda = torch.cuda.is_available()
    with torch.autocast("cuda" if use_cuda else "cpu", dtype=torch.float16, enabled=use_cuda):
        out = model.generate(
            **inputs,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=rep_penalty,
            max_new_tokens=max_new_tokens,
            pad_token_id=tok.pad_token_id,
            eos_token_id=tok.eos_token_id,
            no_repeat_ngram_size=4,
        )
    # Decode only the newly generated tokens (strip the prompt prefix).
    gen_only = out[0][inputs["input_ids"].shape[1]:]
    return tok.decode(gen_only, skip_special_tokens=True)
# Gradio UI: Tatar chat demo with prompt-format and sampling controls.
with gr.Blocks() as demo:
    gr.Markdown("## Татарча чат-демо (Qwen2.5-7B + LoRA)")
    gr.Markdown("Бета-версия. Модель обучена отвечать **по-татарски**. Если переключаться на русский/английский — это ошибка; сообщите нам примеры.")
    with gr.Row():
        system_prompt = gr.Textbox(
            value="Син бары тик татарча гына җавап бир. Җавапларың кыска һәм нейтраль булсын.",
            label="System prompt"
        )
        mode = gr.Radio(choices=["SFT tags", "Qwen chat"], value="SFT tags", label="Формат промпта")
    with gr.Row():
        # Sampling controls, forwarded to respond() via additional_inputs below.
        temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.05, label="temperature")
        top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
        rep_penalty = gr.Slider(1.0, 1.4, value=1.15, step=0.05, label="repetition_penalty")
        max_new_tokens = gr.Slider(16, 512, value=200, step=8, label="max_new_tokens")
    # NOTE(review): `undo_btn`/`retry_btn`/`clear_btn` kwargs were removed from
    # gr.ChatInterface in newer Gradio releases — confirm the Space pins a
    # compatible gradio version, otherwise this raises TypeError at startup
    # (a plausible cause of the "Runtime error" status).
    gr.ChatInterface(
        fn=respond,
        additional_inputs=[system_prompt, mode, temperature, top_p, rep_penalty, max_new_tokens],
        title=None, undo_btn=None, retry_btn=None, clear_btn="Clear"
    )
# NOTE(review): `queue(concurrency_count=...)` was removed in Gradio 4
# (replaced by `default_concurrency_limit`) — verify against the pinned version.
demo.queue(concurrency_count=1).launch()