import argparse
import json
import os
import sys

# Allow running this script directly from the scripts/ directory: register the
# package name and put the project root on sys.path so that the `model` package
# can be imported below.
__package__ = "scripts"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

import time
import torch
import warnings
import uvicorn

from threading import Thread
from queue import Queue
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from model.model_minimind import MiniMindConfig, MiniMindForCausalLM
from model.model_lora import apply_lora, load_lora

warnings.filterwarnings('ignore')

app = FastAPI()


def init_model(args):
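    """Load the tokenizer and model according to the CLI arguments.

    If --load_from points at the native ``model`` directory, a MiniMindForCausalLM
    is built from the raw .pth checkpoint (optionally patched with LoRA weights);
    otherwise the path is treated as a transformers-format checkpoint.
    Note: ``device`` is a module-level variable assigned in the __main__ block
    before this function is called.
    """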
    tokenizer = AutoTokenizer.from_pretrained(args.load_from)
    if 'model' in args.load_from:
        moe_suffix = '_moe' if args.use_moe else ''
        ckp = f'../{args.save_dir}/{args.weight}_{args.hidden_size}{moe_suffix}.pth'
        model = MiniMindForCausalLM(MiniMindConfig(
            hidden_size=args.hidden_size,
            num_hidden_layers=args.num_hidden_layers,
            max_seq_len=args.max_seq_len,
            use_moe=bool(args.use_moe),
            inference_rope_scaling=args.inference_rope_scaling
        ))
        model.load_state_dict(torch.load(ckp, map_location=device), strict=True)
        if args.lora_weight != 'None':
            apply_lora(model)
            load_lora(model, f'../{args.save_dir}/lora/{args.lora_weight}_{args.hidden_size}.pth')
    else:
        model = AutoModelForCausalLM.from_pretrained(args.load_from, trust_remote_code=True)
    print(f'MiniMind model parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M(illion)')
    return model.eval().to(device), tokenizer


class ChatRequest(BaseModel):
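    """Request body for /v1/chat/completions, mirroring the OpenAI chat schema."""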
    model: str
    messages: list
    temperature: float = 0.7
    top_p: float = 0.92
    max_tokens: int = 8192
    stream: bool = False
    tools: list = []


class CustomStreamer(TextStreamer):
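    """TextStreamer subclass that pushes decoded text chunks onto a queue instead of stdout."""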
    def __init__(self, tokenizer, queue):
        super().__init__(tokenizer, skip_prompt=True, skip_special_tokens=True)
        self.queue = queue
        self.tokenizer = tokenizer

    def on_finalized_text(self, text: str, stream_end: bool = False):
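        """Forward each decoded chunk to the queue; ``None`` signals end of stream."""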
        self.queue.put(text)
        if stream_end:
            self.queue.put(None)


def generate_stream_response(messages, temperature, top_p, max_tokens):
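    """Yield OpenAI-style streaming chunks (JSON strings) for the given chat messages."""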
    try:
        # apply_chat_template returns a plain string here; the slice is a rough
        # character-level cap on prompt length, and the tokenizer call applies
        # token-level truncation on top of it.
        new_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)[-max_tokens:]
        inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)

        queue = Queue()
        streamer = CustomStreamer(tokenizer, queue)

        def _generate():
            # Run generation in a background thread; the finally clause guarantees
            # the consumer loop below always terminates, even if generation raises.
            try:
                model.generate(
                    inputs.input_ids,
                    max_new_tokens=max_tokens,
                    do_sample=True,
                    temperature=temperature,
                    top_p=top_p,
                    attention_mask=inputs.attention_mask,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    streamer=streamer
                )
            finally:
                queue.put(None)

        Thread(target=_generate, daemon=True).start()

        while True:
            text = queue.get()
            if text is None:
                yield json.dumps({
                    "choices": [{
                        "delta": {},
                        "finish_reason": "stop"
                    }]
                }, ensure_ascii=False)
                break

            yield json.dumps({
                "choices": [{"delta": {"content": text}}]
            }, ensure_ascii=False)

    except Exception as e:
        yield json.dumps({"error": str(e)})


@app.post("/v1/chat/completions")
async def chat_completions(request: ChatRequest):
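    """OpenAI-compatible chat completion endpoint, handling streaming and non-streaming requests."""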
    try:
        if request.stream:
            def event_stream():
                # Wrap each JSON chunk in SSE "data:" framing and close the stream
                # with the [DONE] sentinel expected by OpenAI-compatible clients.
                for chunk in generate_stream_response(
                        messages=request.messages,
                        temperature=request.temperature,
                        top_p=request.top_p,
                        max_tokens=request.max_tokens
                ):
                    yield f"data: {chunk}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(event_stream(), media_type="text/event-stream")
        else:
            # Same rough character-level cap as the streaming path; token-level
            # truncation happens in the tokenizer call below.
            new_prompt = tokenizer.apply_chat_template(
                request.messages,
                tokenize=False,
                add_generation_prompt=True
            )[-request.max_tokens:]
            inputs = tokenizer(new_prompt, return_tensors="pt", truncation=True).to(device)
            with torch.no_grad():
                generated_ids = model.generate(
                    inputs["input_ids"],
                    max_length=inputs["input_ids"].shape[1] + request.max_tokens,
                    do_sample=True,
                    attention_mask=inputs["attention_mask"],
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    top_p=request.top_p,
                    temperature=request.temperature
                )
            # Decode only the newly generated tokens (everything after the prompt).
            answer = tokenizer.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
            return {
                "id": f"chatcmpl-{int(time.time())}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": "minimind",
                "choices": [
                    {
                        "index": 0,
                        "message": {"role": "assistant", "content": answer},
                        "finish_reason": "stop"
                    }
                ]
            }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Server for MiniMind")
    parser.add_argument('--load_from', default='../model', type=str, help="model load path ('model' dir = native torch weights, any other path = transformers format)")
    parser.add_argument('--save_dir', default='out', type=str, help="directory holding model weights")
    parser.add_argument('--weight', default='full_sft', type=str, help="weight name prefix (pretrain, full_sft, dpo, reason, ppo_actor, grpo, spo)")
    parser.add_argument('--lora_weight', default='None', type=str, help="LoRA weight name ('None' = disabled; options: lora_identity, lora_medical)")
    parser.add_argument('--hidden_size', default=512, type=int, help="hidden dimension (512=Small-26M, 640=MoE-145M, 768=Base-104M)")
    parser.add_argument('--num_hidden_layers', default=8, type=int, help="number of hidden layers (Small/MoE=8, Base=16)")
    parser.add_argument('--max_seq_len', default=8192, type=int, help="maximum sequence length")
    parser.add_argument('--use_moe', default=0, type=int, choices=[0, 1], help="use the MoE architecture (0=no, 1=yes)")
    parser.add_argument('--inference_rope_scaling', default=False, action='store_true', help="enable RoPE positional-encoding extrapolation (4x; only addresses positional encoding)")
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', type=str, help="device to run on")
    args = parser.parse_args()
    device = args.device
    model, tokenizer = init_model(args)
    uvicorn.run(app, host="0.0.0.0", port=8998)
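
# A minimal client sketch (illustrative, not part of the server): with the server
# running locally on port 8998, the endpoint can be exercised with the `requests`
# package, e.g.
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8998/v1/chat/completions",
#       json={"model": "minimind",
#             "messages": [{"role": "user", "content": "Hello"}]},
#   )
#   print(resp.json()["choices"][0]["message"]["content"])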
|
|