import json
import os

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

from train import ModelTrainer

class NovelAIApp:
    """Gradio app that fine-tunes a chat model on uploaded novel text and chats with it."""

    # Base checkpoint handed to ModelTrainer and pre-filled in the model-path box.
    BASE_MODEL = "THUDM/chatglm2-6b"
    # Loaded in __init__ and also passed to ModelTrainer, so both get the same file.
    PROMPTS_PATH = "configs/system_prompts.json"
    # Token budget for the reply alone, independent of how long the prompt is.
    MAX_NEW_TOKENS = 512
    # Markup that ends the assistant's turn; generated text after it is dropped.
    STOP_MARKERS = ("</|assistant|>", "<|user|>")

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.trainer = None

        with open(self.PROMPTS_PATH, "r", encoding="utf-8") as f:
            self.system_prompts = json.load(f)

        # NOTE(review): not read anywhere in this file; kept in case other
        # code depends on it.
        self.current_mood = "暗示"

    def load_model(self, model_path):
        """Load the tokenizer and an 8-bit quantized model from ``model_path``.

        Both objects are assigned only after both loads succeed, so a failed
        load leaves any previously loaded tokenizer/model pair untouched.

        Args:
            model_path: Hub id or local directory accepted by ``from_pretrained``.

        Returns:
            A status string for display in the UI.
        """
        tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            trust_remote_code=True,
            # Quantization is configured via BitsAndBytesConfig; passing
            # load_in_8bit directly to from_pretrained is deprecated.
            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
            device_map="auto",
        )
        self.tokenizer = tokenizer
        self.model = model
        return "模型加载完成!"

    def train_model(self, files):
        """Fine-tune the base model on the uploaded novel text files.

        Args:
            files: Value of the multi-file ``gr.File`` component; ``None`` or
                empty when nothing has been uploaded.

        Returns:
            A status string for the training-status textbox.
        """
        if not files:
            return "请先上传小说文本文件!"

        if self.trainer is None:
            self.trainer = ModelTrainer(self.BASE_MODEL, self.PROMPTS_PATH)

        dataset = self.trainer.prepare_dataset(files)
        self.trainer.train(dataset)
        return "训练完成!"

    @staticmethod
    def _format_history(history):
        """Render prior chat turns in the ``<|role|>...</|role|>`` prompt markup.

        Accepts both Gradio history formats: ``type="messages"`` (a list of
        ``{"role": ..., "content": ...}`` dicts, which this app's ChatInterface
        is configured to use) and the tuple format (a list of
        ``[user, assistant]`` pairs). Roles other than user/assistant are skipped.
        """
        turns = []
        for entry in history or []:
            if isinstance(entry, dict):
                role = entry.get("role")
                if role in ("user", "assistant"):
                    turns.append((role, entry.get("content") or ""))
            else:
                user_text, bot_text = entry
                turns.append(("user", user_text))
                # Tuple-format history uses None for a reply not yet produced.
                if bot_text is not None:
                    turns.append(("assistant", bot_text))
        return "".join(f"<|{role}|>{content}</|{role}|>\n" for role, content in turns)

    def _build_prompt(self, message, history):
        """Assemble system prompt, prior turns and the new user message."""
        system_prompt = self.system_prompts.get("base_prompt", "")
        return (
            f"<|system|>{system_prompt}</|system|>\n"
            f"{self._format_history(history)}"
            f"<|user|>{message}</|user|>\n"
            "<|assistant|>"
        )

    def generate_text(self, message, history):
        """Generate the assistant reply; signature matches ``gr.ChatInterface``'s ``fn``.

        Args:
            message: The new user message.
            history: Prior turns, in messages or tuple format.

        Returns:
            The reply text, or a reminder to load the model first.
        """
        if self.model is None or self.tokenizer is None:
            return "请先加载模型!"

        prompt = self._build_prompt(message, history)
        encoded = self.tokenizer(prompt, return_tensors="pt")
        # Inputs must live on the same device as the (device_map="auto") model.
        input_ids = encoded["input_ids"].to(self.model.device)
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.model.device)

        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # Bounds the reply only; a total max_length would leave no room
            # once the conversation history grows long.
            max_new_tokens=self.MAX_NEW_TOKENS,
            # temperature / top_p only take effect when sampling is enabled.
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
        )

        # Decode only the newly generated tokens so the prompt is never echoed,
        # whether or not the role markers are special tokens.
        new_tokens = outputs[0][input_ids.shape[1]:]
        response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        for marker in self.STOP_MARKERS:
            response = response.split(marker, 1)[0]
        return response.strip()

    def create_interface(self):
        """Build the Gradio UI with a training tab and a chat tab.

        Returns:
            The top-level ``gr.Blocks`` app (not yet launched).
        """
        # The theme must be set on the top-level Blocks; a theme passed to a
        # ChatInterface nested inside another Blocks is not applied.
        with gr.Blocks(theme="soft") as interface:
            gr.Markdown("# 猫娘对话助手")

            with gr.Tab("模型训练"):
                file_output = gr.File(
                    file_count="multiple",
                    label="上传小说文本文件",
                )
                train_button = gr.Button("开始训练")
                train_output = gr.Textbox(label="训练状态")

                train_button.click(
                    fn=self.train_model,
                    inputs=[file_output],
                    outputs=[train_output],
                )

            with gr.Tab("对话"):
                # generate_text refuses to answer until load_model has run,
                # so expose model loading directly in the chat tab.
                with gr.Row():
                    model_path = gr.Textbox(value=self.BASE_MODEL, label="模型路径")
                    load_button = gr.Button("加载模型")
                load_status = gr.Textbox(label="加载状态")

                load_button.click(
                    fn=self.load_model,
                    inputs=[model_path],
                    outputs=[load_status],
                )

                gr.ChatInterface(
                    fn=self.generate_text,
                    title="与猫娘对话",
                    description="来和可爱的猫娘聊天吧~",
                    examples=["今天天气真好呢", "你在做什么呢?", "要不要一起玩?"],
                    cache_examples=False,
                    type="messages",
                )

        return interface


app = NovelAIApp()
interface = app.create_interface()


if __name__ == "__main__":
    # Launch only when run as a script, so importing this module (e.g. from
    # tooling or tests) does not start a web server.
    # NOTE(review): share=True requests a public share link and 0.0.0.0
    # listens on all network interfaces, so anyone with the URL can chat and
    # start training runs. Confirm this exposure is intended.
    interface.launch(
        server_name="0.0.0.0",
        share=True,
        ssl_verify=False,
    )