| import gradio as gr |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from peft import PeftModel |
|
|
| |
# Hugging Face Hub repository ids: the base causal LM, and the LoRA adapter
# that is loaded on top of it further below.
BASE_MODEL_ID: str = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
ADAPTER_MODEL_ID: str = "Snow2222/autotrain-fst"
|
|
# The tokenizer is taken from the base checkpoint (the adapter repo is not consulted).
print("🚀 正在加载 Tokenizer...")  # "Loading tokenizer..."
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL_ID,
)
|
|
# Load the base causal LM and place it entirely on a single device.
print("🚀 正在加载 Base Model(基础模型)...")  # "Loading base model..."
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    # fp16 on GPU, fp32 on CPU.
    torch_dtype=torch.float16 if use_cuda else torch.float32,
    # Pin the whole model to `device` rather than using device_map="auto".
    # "auto" may shard the model across GPUs or offload layers to CPU, which
    # conflicts with the later `.to(device)` on the PEFT model and with
    # `respond` sending every input tensor to this single `device`.
    device_map={"": device},
)
|
|
| |
# Target embedding-matrix size for the adapter.
# NOTE(review): 151665 presumably matches the vocabulary size the adapter was
# trained/saved with (likely len(tokenizer)) — confirm before changing it.
new_vocab_size = 151665
print("🔧 调整 vocab_size 以匹配 LoRA...")  # "Resizing vocab_size to match the LoRA adapter..."
base_model.resize_token_embeddings(new_vocab_size)
|
|
# Wrap the (resized) base model with the LoRA adapter weights from the Hub,
# then make sure the wrapped model lives on the selected device.
print("🚀 正在加载 LoRA 适配器...")  # "Loading LoRA adapter..."
model = PeftModel.from_pretrained(
    base_model,
    ADAPTER_MODEL_ID,
)
model = model.to(device)
|
|
| |
| |
# Adapter whose hyper-parameters are overridden below.
adapter_name = model.active_adapter or "default"
# NOTE(review): depending on the PEFT version the active adapter may be
# reported as a list; use the first entry in that case.
if isinstance(adapter_name, (list, tuple)):
    adapter_name = adapter_name[0]

# Override lora_alpha at inference time. Updating the config object alone
# does not affect the already-built LoRA layers, so every layer's cached
# values are patched in the loop below as well.
NEW_LORA_ALPHA = 128
peft_config = model.peft_config[adapter_name]
print(f"【原始】lora_alpha: {peft_config.lora_alpha}")  # original value
peft_config.lora_alpha = NEW_LORA_ALPHA
print(f"【更新后】lora_alpha: {peft_config.lora_alpha}")  # updated value

# PEFT multiplies the LoRA update by `scaling = lora_alpha / r`
# (`lora_alpha / sqrt(r)` when rsLoRA is enabled), not by lora_alpha itself.
# Writing the raw alpha into `scaling` would amplify the adapter's
# contribution by a factor of r.
use_rslora = getattr(peft_config, "use_rslora", False)
for _module_name, module in model.named_modules():
    layer_scaling = getattr(module, "scaling", None)
    if not isinstance(layer_scaling, dict) or adapter_name not in layer_scaling:
        continue  # not a LoRA layer for this adapter
    # Prefer the rank stored on the layer itself; fall back to the config's r.
    layer_ranks = getattr(module, "r", None)
    rank = (
        layer_ranks[adapter_name]
        if isinstance(layer_ranks, dict) and adapter_name in layer_ranks
        else peft_config.r
    )
    # Keep the layer's own alpha record consistent with the new scaling.
    layer_alphas = getattr(module, "lora_alpha", None)
    if isinstance(layer_alphas, dict):
        layer_alphas[adapter_name] = NEW_LORA_ALPHA
    layer_scaling[adapter_name] = NEW_LORA_ALPHA / (rank ** 0.5 if use_rslora else rank)
| |
|
|
def _history_turns(history):
    """Yield ``(speaker, text)`` pairs from a Gradio chat history.

    Accepts both shapes ``gr.ChatInterface`` may pass: a list of
    ``[user, assistant]`` pairs, or a list of ``{"role": ..., "content": ...}``
    dicts. Non-text entries (e.g. uploaded files) and empty turns are skipped.
    Speaker labels match the plain-text prompt format used by ``respond``.
    """
    for item in history or []:
        if isinstance(item, dict):
            speaker = {"user": "用户", "assistant": "助手"}.get(item.get("role"))
            text = item.get("content")
            if speaker and isinstance(text, str) and text:
                yield speaker, text
        else:
            for speaker, text in zip(("用户", "助手"), item):
                if isinstance(text, str) and text:
                    yield speaker, text


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Generate one assistant reply for the Gradio chat UI.

    Args:
        message: Latest user message.
        history: Previous turns as supplied by ``gr.ChatInterface``.
        system_message: Text placed at the top of the prompt.
        max_tokens: Maximum number of newly generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded reply text, without the prompt.
    """
    print("==== 🚀 处理用户输入 ====")
    print(f"用户输入: {message}")

    # Plain-text prompt: system message, prior turns, the new user turn, and a
    # trailing assistant cue so the model continues as the assistant.
    prompt_lines = [system_message]
    prompt_lines.extend(f"{speaker}: {text}" for speaker, text in _history_turns(history))
    prompt_lines.append(f"用户: {message}")
    prompt_lines.append("助手:")
    prompt = "\n".join(prompt_lines)
    print(f"📡 处理 Prompt: {prompt}")

    # Tokenize the full prompt; tokenizing only `message` would silently drop
    # the system message and the conversation history.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(device)
    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=int(max_tokens),  # sliders may deliver floats
            # temperature/top_p only take effect when sampling is enabled.
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
        )
    # For a decoder-only model generate() returns prompt + continuation;
    # decode only the new tokens so the reply does not echo the prompt.
    response = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True).strip()
    print(f"✅ 生成结果: {response}")

    return response
|
|
| |
# Extra controls shown under the chat box. ChatInterface passes their values to
# `respond` positionally after (message, history), in exactly this order.
_extra_inputs = [
    gr.Textbox(label="System message", value="You are a friendly Chatbot."),
    gr.Slider(label="Max new tokens", minimum=1, maximum=1024, step=1, value=256),
    gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
    gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, step=0.05, value=0.95),
]
demo = gr.ChatInterface(respond, additional_inputs=_extra_inputs)
|
|
def _launch_ui():
    """Print a startup banner and serve the Gradio chat UI."""
    print("🌍 启动 Gradio 界面...")  # "Launching Gradio UI..."
    demo.launch()


# Only start the web server when run as a script, not when imported.
if __name__ == "__main__":
    _launch_ui()
|
|