szxllm commited on
Commit
ebd97f6
·
verified ·
1 Parent(s): e68927b

Delete infer_sft.py

Browse files
Files changed (1) hide show
  1. infer_sft.py +0 -407
infer_sft.py DELETED
@@ -1,407 +0,0 @@
1
- """
2
- Flask推理界面 - 多模态Dense Transformer (适配 Qwen Tokenizer 版)
3
- """
4
-
5
- import os
6
- import torch
7
- import torch.nn.functional as F
8
- from flask import Flask, render_template, request, jsonify
9
- from transformers import AutoTokenizer
10
- from PIL import Image
11
- import json
12
- import io
13
- import base64
14
- from pathlib import Path
15
- from typing import Optional
16
-
17
- # 确保引入路径正确,根据你之前的文件结构
18
- from model import MultiModalDenseTransformer
19
- # 注意:UnifiedMultiModalPreprocessor 之前是在 continual_learning.py 中定义的
20
- # 如果你移动了它,请修改这里的导入路径
21
- from continual_learning import UnifiedMultiModalPreprocessor
22
- # 如果没有 image_transform,我们需要在这里定义或导入
23
- from torchvision import transforms
24
-
25
# Route Hugging Face downloads through the domestic (China) mirror.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

# Image preprocessing pipeline — must stay consistent with training.
image_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
34
-
35
class ModelInference:
    """Inference wrapper for a pretrained MultiModalDenseTransformer.

    Loads the Qwen tokenizer and checkpoint weights, then generates text
    (optionally conditioned on an image) via :meth:`generate_text`.
    """

    def __init__(
        self,
        checkpoint_path: str,
        tokenizer_name: str,
        config_path: Optional[str] = None,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ):
        """Build the tokenizer and model, then load checkpoint weights.

        Args:
            checkpoint_path: Path to a ``.pt`` checkpoint — either a raw
                state_dict or a dict containing ``'model_state_dict'``.
            tokenizer_name: HF hub name of the tokenizer used for pretraining.
            config_path: Optional JSON file overriding the built-in model config.
            device: Target device. NOTE: the default is evaluated once at class
                definition time, which is fine here (CUDA availability is static).

        Raises:
            Exception: re-raised from tokenizer or model/checkpoint loading.
        """
        self.device = torch.device(device)
        print(f"Using device: {self.device}")

        # 1. Load the tokenizer (must match the one used during pretraining).
        print(f"Loading tokenizer: {tokenizer_name}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                use_fast=True,
                trust_remote_code=True
            )
            # Qwen tokenizers ship without a pad token; reuse EOS for padding.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        except Exception as e:
            print(f"Error loading tokenizer: {e}")
            raise  # bare raise preserves the original traceback

        # 2. Model hyper-parameters — must exactly match pretrain.py.
        if config_path and Path(config_path).exists():
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            # [CRITICAL] These mirror the settings used in pretrain.py.
            self.config = {
                'model_dim': 1536,
                'vocab_size': len(self.tokenizer),  # auto-adapts to Qwen (~151665)
                'n_layers': 12,
                'n_heads': 12,
                'n_kv_heads': 4,
                'head_dim': None,                   # derived automatically
                'max_seq_len': 512,
                'dropout': 0.0,                     # disable dropout at inference
                'use_moe': False,
                'use_adapter': False,               # adapters were off in pretraining
                'use_lora': False,                  # LoRA was off in pretraining
                'rope_scaling_type': "yarn",
                'use_multimodal_fusion': False,
                'use_contrastive': False
            }

        # 3. Instantiate the architecture and load the weights.
        print("Initializing model architecture...")
        try:
            self.model = MultiModalDenseTransformer(**self.config)
            self.preprocessor = UnifiedMultiModalPreprocessor(
                model_dim=self.config['model_dim']
            )

            # 4. Load the checkpoint.
            print(f"Loading checkpoint from {checkpoint_path}...")
            # weights_only=False supports full checkpoint dicts, but it
            # unpickles arbitrary objects — only load checkpoints you trust.
            checkpoint = torch.load(
                checkpoint_path,
                map_location=self.device,
                weights_only=False
            )

            # Unwrap a full training checkpoint to its state_dict if needed.
            if 'model_state_dict' in checkpoint:
                print("Found 'model_state_dict' in checkpoint.")
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint

            # Strip the 'module.' prefix that DDP training adds to parameter names.
            new_state_dict = {}
            for k, v in state_dict.items():
                if k.startswith('module.'):
                    new_state_dict[k[7:]] = v
                else:
                    new_state_dict[k] = v

            # strict=False tolerates non-critical mismatches (e.g. loss buffers).
            missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
            if missing:
                print(f"Warning: Missing keys: {len(missing)}")
            if unexpected:
                print(f"Warning: Unexpected keys: {len(unexpected)}")

            self.model.to(self.device)
            self.preprocessor.to(self.device)
            self.model.eval()

            print("Model loaded successfully!")
            print(f"Total parameters: {sum(p.numel() for p in self.model.parameters()) / 1e6:.2f}M")

        except Exception as e:
            print(f"Error initializing model: {e}")
            raise  # bare raise preserves the original traceback

    @torch.no_grad()
    def generate_text(
        self,
        prompt: str,
        max_new_tokens: int = 128,
        temperature: float = 0.7,
        top_k: int = 10,
        top_p: float = 0.9,
        repetition_penalty: float = 1.2,
        image: Optional[Image.Image] = None
    ) -> str:
        """Generate a response for ``prompt``, optionally conditioned on ``image``.

        Returns the decoded answer with the prompt echo and any hallucinated
        follow-on sections stripped, or an ``"Error: ..."`` string on failure.
        """
        formatted_prompt = f"Instruction: {prompt}\nResponse:"
        # Encode the instruction-formatted prompt.
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
        input_ids = inputs['input_ids'].to(self.device)

        # Build the segment-based input format MultiModalDenseTransformer expects.
        input_data = {'segments': []}

        # Optional image conditioning.
        if image is not None:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image_tensor = image_transform(image).unsqueeze(0).to(self.device)
            try:
                # process_batch takes (batch_data, modality_type) and returns
                # a list of segment dicts ready for input_data.
                mod_segments = self.preprocessor.process_batch(image_tensor, 'image')
                input_data['segments'].extend(mod_segments)
            except Exception as e:
                # Best-effort: fall back to text-only generation on image failure.
                print(f"Warning: Image processing skipped due to error: {e}")

        # Append the text segment last.
        input_data['segments'].append({
            'type': 'text',
            'data': input_ids,
            'modality_id': 0
        })

        try:
            # Delegate sampling to the model's own generate().
            generated_ids = self.model.generate(
                input_data,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id
            )

            full_output = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            print(f"\n====== [DEBUG 原始输出] ======\n{full_output}\n==============================\n")

            # Keep only the text after the "Response:" marker.
            if "Response:" in full_output:
                answer = full_output.split("Response:")[-1].strip()
            else:
                answer = full_output

            # Truncate at the first stop word: the model tends to hallucinate
            # further "Instruction:"/"Ingredients:"-style sections.
            stop_words = [
                "Instruction", "Input", "###", "Response",
                "User:", "Assistant:", "\n\n"  # double newline usually ends a paragraph
            ]
            for stop_word in stop_words:
                if stop_word in answer:
                    answer = answer.split(stop_word)[0].strip()

            # Drop a first line that merely echoes the prompt (echo suppression).
            # str.split always yields at least one element, so no length guard
            # is needed before indexing lines[0].
            lines = answer.split('\n')
            if prompt.lower() in lines[0].lower():
                answer = "\n".join(lines[1:]).strip()
            return answer

        except Exception as e:
            print(f"Generation error: {e}")
            import traceback
            traceback.print_exc()
            return f"Error: {str(e)}"
229
-
230
-
231
# Flask application plus the global model handle (populated in main()).
app = Flask(__name__)
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024  # cap uploads at 16 MB
model_instance = None
235
-
236
@app.route('/')
def index():
    """Render the landing page, exposing the model config (if loaded) to the template."""
    if model_instance is None:
        display_config = {}
    else:
        display_config = model_instance.config.copy()
    return render_template('index.html', config=display_config)
240
-
241
@app.route('/generate', methods=['POST'])
def generate():
    """POST endpoint: parse sampling params (and an optional base64 image),
    run generation, and return the result as JSON."""
    try:
        data = request.json
        prompt = data.get('prompt', '')
        if not prompt.strip():
            return jsonify({'error': '请输入提示文本'}), 400

        # Sampling hyper-parameters with server-side fallbacks.
        max_tokens = int(data.get('max_tokens', 100))
        temperature = float(data.get('temperature', 0.7))
        top_k = int(data.get('top_k', 40))
        top_p = float(data.get('top_p', 0.9))
        repetition_penalty = float(data.get('repetition_penalty', 1.1))

        # Decode an optional data-URL image ("data:image/...;base64,<payload>").
        image = None
        if data.get('image'):
            try:
                image_data = base64.b64decode(data['image'].split(',')[1])
                image = Image.open(io.BytesIO(image_data))
            except Exception as e:
                # Best-effort: generation proceeds text-only on a bad image.
                print(f"Image load error: {e}")

        output = model_instance.generate_text(
            prompt, max_tokens, temperature, top_k, top_p, repetition_penalty, image
        )
        return jsonify({'output': output})

    except Exception as e:
        return jsonify({'error': str(e)}), 500
270
-
271
def create_html_template():
    """Write the single-page inference UI to templates/index.html."""
    html_content = '''
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Model Inference</title>
<style>
body { font-family: sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; background: #f0f2f5; }
.container { background: white; padding: 30px; border-radius: 12px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); }
h1 { color: #1a73e8; text-align: center; }
textarea { width: 100%; padding: 10px; border: 1px solid #ddd; border-radius: 8px; margin: 10px 0; min-height: 100px; }
.controls { display: grid; grid-template-columns: 1fr 1fr; gap: 15px; margin: 20px 0; background: #f8f9fa; padding: 15px; border-radius: 8px; }
button { background: #1a73e8; color: white; border: none; padding: 12px 24px; border-radius: 6px; cursor: pointer; width: 100%; font-size: 16px; transition: background 0.3s; }
button:hover { background: #1557b0; }
button:disabled { background: #ccc; }
#output { margin-top: 20px; padding: 20px; background: #f8f9fa; border-radius: 8px; white-space: pre-wrap; min-height: 100px; border: 1px solid #e0e0e0; }
.loading { color: #666; font-style: italic; }
</style>
</head>
<body>
<div class="container">
<h1>🚀 模型在线推理</h1>

<div>
<label><strong>提示词 (Prompt):</strong></label>
<textarea id="prompt" placeholder="请输入你的问题..."></textarea>
</div>

<div class="controls">
<div>
<label>Max Tokens: <span id="maxTokensVal">128</span></label>
<input type="range" id="maxTokens" min="32" max="1024" value="128" style="width:100%" oninput="document.getElementById('maxTokensVal').innerText=this.value">
</div>
<div>
<label>Temperature: <span id="tempVal">0.7</span></label>
<input type="range" id="temperature" min="0.1" max="1.5" step="0.1" value="0.7" style="width:100%" oninput="document.getElementById('tempVal').innerText=this.value">
</div>
</div>

<button id="btn" onclick="generate()">生成 (Generate)</button>

<div id="output">结果将显示在这里...</div>
</div>

<script>
async function generate() {
const prompt = document.getElementById('prompt').value;
if(!prompt) return alert("请输入内容");

const btn = document.getElementById('btn');
const out = document.getElementById('output');

btn.disabled = true;
btn.innerText = "生成中...";
out.innerHTML = '<div class="loading">正在思考中...</div>';

try {
const res = await fetch('/generate', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({
prompt: prompt,
max_tokens: parseInt(document.getElementById('maxTokens').value),
temperature: parseFloat(document.getElementById('temperature').value)
})
});
const data = await res.json();
if(data.error) out.innerText = "Error: " + data.error;
else out.innerText = data.output;
} catch(e) {
out.innerText = "请求失败: " + e;
} finally {
btn.disabled = false;
btn.innerText = "生成 (Generate)";
}
}
</script>
</body>
</html>
'''

    Path('templates').mkdir(exist_ok=True)
    Path('templates/index.html').write_text(html_content, encoding='utf-8')
358
-
359
def main():
    """CLI entry point: resolve the checkpoint path, optionally expose the
    server through an ngrok tunnel, then start the Flask app."""
    import argparse
    parser = argparse.ArgumentParser()
    # Default points at the checkpoint produced by post-training.
    parser.add_argument("--checkpoint", type=str, default="/root/multimodal/checkpoints/posttrain/final_model.pt")
    parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--port", type=int, default=5001)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    args = parser.parse_args()

    if not Path(args.checkpoint).exists():
        # Fall back to the most recent step checkpoint. Path.glob() yields
        # files in filesystem-dependent order, so sort by modification time
        # to actually pick the latest (steps[-1] on the raw glob was wrong).
        steps = sorted(
            Path("checkpoints/pretrain").glob("step_*.pt"),
            key=lambda p: p.stat().st_mtime,
        )
        if steps:
            print(f"未找到 final_model.pt,尝试使用最新的 checkpoint: {steps[-1]}")
            args.checkpoint = str(steps[-1])
        else:
            print(f"错误: 找不到检查点文件: {args.checkpoint}")
            return

    # Optional public tunnel via ngrok (best-effort; failures are non-fatal).
    try:
        from pyngrok import ngrok, conf

        # If ngrok is slow from mainland China, set region='ap' or 'au':
        # conf.get_default().region = "ap"

        # Tunnel the chosen port and print the public URL.
        public_url = ngrok.connect(args.port).public_url
        print(f"\n========================================")
        print(f"🎉 公网访问地址 (发给朋友): {public_url}")
        print(f"========================================\n")
    except ImportError:
        print("未安装 pyngrok,无法自动生成公网链接。")
        print("提示: pip install pyngrok")
    except Exception as e:
        print(f"Ngrok 启动失败: {e}")

    create_html_template()

    global model_instance
    model_instance = ModelInference(args.checkpoint, args.tokenizer)

    print(f"\n服务已启动: http://{args.host}:{args.port}")
    # NOTE(review): debug=True enables the Werkzeug debugger; combined with a
    # public ngrok URL this allows remote code execution — disable for any
    # deployment beyond local testing.
    app.run(host=args.host, port=args.port,
            debug=True,
            use_reloader=False)

if __name__ == "__main__":
    main()