szxllm committed on
Commit
9c85325
·
verified ·
1 Parent(s): ebd97f6

Update infer.py

Browse files
Files changed (1) hide show
  1. infer.py +16 -57
infer.py CHANGED
@@ -1,7 +1,3 @@
1
- """
2
- Flask推理界面 - 多模态Dense Transformer (适配 Qwen Tokenizer 版)
3
- """
4
-
5
  import os
6
  import torch
7
  import torch.nn.functional as F
@@ -14,18 +10,12 @@ import base64
14
  from pathlib import Path
15
  from typing import Optional
16
 
17
- # 确保引入路径正确,根据你之前的文件结构
18
  from model import MultiModalDenseTransformer
19
- # 注意:UnifiedMultiModalPreprocessor 之前是在 continual_learning.py 中定义的
20
- # 如果你移动了它,请修改这里的导入路径
21
  from continual_learning import UnifiedMultiModalPreprocessor
22
- # 如果没有 image_transform,我们需要在这里定义或导入
23
  from torchvision import transforms
24
 
25
- # 设置国内镜像
26
  os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
27
 
28
- # 定义图像预处理 (与 training 保持一致)
29
  image_transform = transforms.Compose([
30
  transforms.Resize((224, 224)),
31
  transforms.ToTensor(),
@@ -33,8 +23,6 @@ image_transform = transforms.Compose([
33
  ])
34
 
35
  class ModelInference:
36
- """模型推理类"""
37
-
38
  def __init__(
39
  self,
40
  checkpoint_path: str,
@@ -44,8 +32,6 @@ class ModelInference:
44
  ):
45
  self.device = torch.device(device)
46
  print(f"Using device: {self.device}")
47
-
48
- # 1. 加载 Tokenizer (与预训练一致)
49
  print(f"Loading tokenizer: {tokenizer_name}...")
50
  try:
51
  self.tokenizer = AutoTokenizer.from_pretrained(
@@ -59,26 +45,24 @@ class ModelInference:
59
  except Exception as e:
60
  print(f"Error loading tokenizer: {e}")
61
  raise e
62
-
63
- # 2. 配置模型参数 (必须与 pretrain.py 中的配置完全一致)
64
  if config_path and Path(config_path).exists():
65
  with open(config_path, 'r') as f:
66
  self.config = json.load(f)
67
  else:
68
- # [CRITICAL] 这里使用了你在 pretrain.py 中使用的参数
69
  self.config = {
70
- 'model_dim': 1536, # 预训练设置
71
- 'vocab_size': len(self.tokenizer), # 自动适配 Qwen (约 151665)
72
- 'n_layers': 12, # 预训练设置
73
- 'n_heads': 12, # 预训练设置
74
- 'n_kv_heads': 4, # 预训练设置
75
- 'head_dim': None, # 自动计算
76
- 'max_seq_len': 512, # 预训练设置
77
- 'dropout': 0.0, # 推理时关闭 dropout
78
- 'use_moe': False, # 预训练设置
79
- 'use_adapter': False, # 预训练未开启 Adapter
80
- 'use_lora': False, # 预训练未开启 LoRA
81
- 'rope_scaling_type': "yarn" # 预训练设置
82
  }
83
 
84
  # 3. 初始化模型结构
@@ -91,21 +75,18 @@ class ModelInference:
91
 
92
  # 4. 加载权重
93
  print(f"Loading checkpoint from {checkpoint_path}...")
94
- # weights_only=False 是为了支持加载完整的 checkpoint 字典
95
  checkpoint = torch.load(
96
  checkpoint_path,
97
  map_location=self.device,
98
  weights_only=False
99
  )
100
 
101
- # 提取 state_dict
102
  if 'model_state_dict' in checkpoint:
103
  print("Found 'model_state_dict' in checkpoint.")
104
  state_dict = checkpoint['model_state_dict']
105
  else:
106
  state_dict = checkpoint
107
 
108
- # 处理可能的键名不匹配 (如 DDP 训练产生的 'module.' 前缀)
109
  new_state_dict = {}
110
  for k, v in state_dict.items():
111
  if k.startswith('module.'):
@@ -113,7 +94,6 @@ class ModelInference:
113
  else:
114
  new_state_dict[k] = v
115
 
116
- # 加载权重 (strict=False 允许忽略一些非关键的不匹配,如 loss 缓存等)
117
  missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
118
  if missing:
119
  print(f"Warning: Missing keys: {len(missing)}")
@@ -143,40 +123,29 @@ class ModelInference:
143
  image: Optional[Image.Image] = None
144
  ) -> str:
145
  """生成文本"""
146
-
147
- # 编码输入
148
  inputs = self.tokenizer(prompt, return_tensors="pt")
149
  input_ids = inputs['input_ids'].to(self.device)
150
-
151
- # 构建 MultiModalDenseTransformer 需要的输入格式
152
  input_data = {'segments': []}
153
 
154
  # 处理图像
155
  if image is not None:
156
  if image.mode != 'RGB':
157
  image = image.convert('RGB')
158
- # 简单的图像处理
159
  image_tensor = image_transform(image).unsqueeze(0).to(self.device)
160
- # 这里假设预处理器能处理这种输入
161
  try:
162
- # process_batch 接受 (batch_data, modality_type) 并返回 segments 列表
163
  mod_segments = self.preprocessor.process_batch(image_tensor, 'image')
164
- # 将返回的 segment 列表合并到 input_data
165
  for seg in mod_segments:
166
  input_data['segments'].append(seg)
167
  except Exception as e:
168
  print(f"Warning: Image processing skipped due to error: {e}")
169
-
170
- # 添加文本段
171
  input_data['segments'].append({
172
  'type': 'text',
173
  'data': input_ids,
174
  'modality_id': 0
175
  })
176
 
177
- # 生成
178
  try:
179
- # 使用模型自带的 generate 方法
180
  generated_ids = self.model.generate(
181
  input_data,
182
  max_new_tokens=max_new_tokens,
@@ -188,19 +157,11 @@ class ModelInference:
188
  eos_token_id=self.tokenizer.eos_token_id,
189
  pad_token_id=self.tokenizer.pad_token_id
190
  )
191
-
192
- # 解码
193
- # 注意:生成的 ids 可能包含原始输入,或者只包含新生成的 token
194
- # MultiModalDenseTransformer.generate 通常返回完整的序列
195
  generated_text = self.tokenizer.decode(
196
  generated_ids[0],
197
  skip_special_tokens=True
198
  )
199
-
200
- # 如果包含 prompt,可以选择移除它只显示新内容
201
- # if generated_text.startswith(prompt):
202
- # generated_text = generated_text[len(prompt):]
203
-
204
  return generated_text
205
 
206
  except Exception as e:
@@ -209,8 +170,6 @@ class ModelInference:
209
  traceback.print_exc()
210
  return f"Error: {str(e)}"
211
 
212
-
213
- # 全局模型实例
214
  model_instance = None
215
  app = Flask(__name__)
216
  app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
@@ -274,7 +233,7 @@ def create_html_template():
274
  </head>
275
  <body>
276
  <div class="container">
277
- <h1>🚀 模型在线推理</h1>
278
 
279
  <div>
280
  <label><strong>提示词 (Prompt):</strong></label>
 
 
 
 
 
1
  import os
2
  import torch
3
  import torch.nn.functional as F
 
10
  from pathlib import Path
11
  from typing import Optional
12
 
 
13
  from model import MultiModalDenseTransformer
 
 
14
  from continual_learning import UnifiedMultiModalPreprocessor
 
15
  from torchvision import transforms
16
 
 
17
  os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
18
 
 
19
  image_transform = transforms.Compose([
20
  transforms.Resize((224, 224)),
21
  transforms.ToTensor(),
 
23
  ])
24
 
25
  class ModelInference:
 
 
26
  def __init__(
27
  self,
28
  checkpoint_path: str,
 
32
  ):
33
  self.device = torch.device(device)
34
  print(f"Using device: {self.device}")
 
 
35
  print(f"Loading tokenizer: {tokenizer_name}...")
36
  try:
37
  self.tokenizer = AutoTokenizer.from_pretrained(
 
45
  except Exception as e:
46
  print(f"Error loading tokenizer: {e}")
47
  raise e
48
+
 
49
  if config_path and Path(config_path).exists():
50
  with open(config_path, 'r') as f:
51
  self.config = json.load(f)
52
  else:
 
53
  self.config = {
54
+ 'model_dim': 1536,
55
+ 'vocab_size': len(self.tokenizer),
56
+ 'n_layers': 12,
57
+ 'n_heads': 12,
58
+ 'n_kv_heads': 4,
59
+ 'head_dim': None,
60
+ 'max_seq_len': 512,
61
+ 'dropout': 0.0,
62
+ 'use_moe': False,
63
+ 'use_adapter': False,
64
+ 'use_lora': False,
65
+ 'rope_scaling_type': "yarn"
66
  }
67
 
68
  # 3. 初始化模型结构
 
75
 
76
  # 4. 加载权重
77
  print(f"Loading checkpoint from {checkpoint_path}...")
 
78
  checkpoint = torch.load(
79
  checkpoint_path,
80
  map_location=self.device,
81
  weights_only=False
82
  )
83
 
 
84
  if 'model_state_dict' in checkpoint:
85
  print("Found 'model_state_dict' in checkpoint.")
86
  state_dict = checkpoint['model_state_dict']
87
  else:
88
  state_dict = checkpoint
89
 
 
90
  new_state_dict = {}
91
  for k, v in state_dict.items():
92
  if k.startswith('module.'):
 
94
  else:
95
  new_state_dict[k] = v
96
 
 
97
  missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
98
  if missing:
99
  print(f"Warning: Missing keys: {len(missing)}")
 
123
  image: Optional[Image.Image] = None
124
  ) -> str:
125
  """生成文本"""
 
 
126
  inputs = self.tokenizer(prompt, return_tensors="pt")
127
  input_ids = inputs['input_ids'].to(self.device)
 
 
128
  input_data = {'segments': []}
129
 
130
  # 处理图像
131
  if image is not None:
132
  if image.mode != 'RGB':
133
  image = image.convert('RGB')
 
134
  image_tensor = image_transform(image).unsqueeze(0).to(self.device)
 
135
  try:
 
136
  mod_segments = self.preprocessor.process_batch(image_tensor, 'image')
 
137
  for seg in mod_segments:
138
  input_data['segments'].append(seg)
139
  except Exception as e:
140
  print(f"Warning: Image processing skipped due to error: {e}")
141
+
 
142
  input_data['segments'].append({
143
  'type': 'text',
144
  'data': input_ids,
145
  'modality_id': 0
146
  })
147
 
 
148
  try:
 
149
  generated_ids = self.model.generate(
150
  input_data,
151
  max_new_tokens=max_new_tokens,
 
157
  eos_token_id=self.tokenizer.eos_token_id,
158
  pad_token_id=self.tokenizer.pad_token_id
159
  )
160
+
 
 
 
161
  generated_text = self.tokenizer.decode(
162
  generated_ids[0],
163
  skip_special_tokens=True
164
  )
 
 
 
 
 
165
  return generated_text
166
 
167
  except Exception as e:
 
170
  traceback.print_exc()
171
  return f"Error: {str(e)}"
172
 
 
 
173
  model_instance = None
174
  app = Flask(__name__)
175
  app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
 
233
  </head>
234
  <body>
235
  <div class="container">
236
+ <h1> 模型在线推理</h1>
237
 
238
  <div>
239
  <label><strong>提示词 (Prompt):</strong></label>