Update infer.py
Browse files
infer.py
CHANGED
|
@@ -1,7 +1,3 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Flask推理界面 - 多模态Dense Transformer (适配 Qwen Tokenizer 版)
|
| 3 |
-
"""
|
| 4 |
-
|
| 5 |
import os
|
| 6 |
import torch
|
| 7 |
import torch.nn.functional as F
|
|
@@ -14,18 +10,12 @@ import base64
|
|
| 14 |
from pathlib import Path
|
| 15 |
from typing import Optional
|
| 16 |
|
| 17 |
-
# 确保引入路径正确,根据你之前的文件结构
|
| 18 |
from model import MultiModalDenseTransformer
|
| 19 |
-
# 注意:UnifiedMultiModalPreprocessor 之前是在 continual_learning.py 中定义的
|
| 20 |
-
# 如果你移动了它,请修改这里的导入路径
|
| 21 |
from continual_learning import UnifiedMultiModalPreprocessor
|
| 22 |
-
# 如果没有 image_transform,我们需要在这里定义或导入
|
| 23 |
from torchvision import transforms
|
| 24 |
|
| 25 |
-
# 设置国内镜像
|
| 26 |
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
| 27 |
|
| 28 |
-
# 定义图像预处理 (与 training 保持一致)
|
| 29 |
image_transform = transforms.Compose([
|
| 30 |
transforms.Resize((224, 224)),
|
| 31 |
transforms.ToTensor(),
|
|
@@ -33,8 +23,6 @@ image_transform = transforms.Compose([
|
|
| 33 |
])
|
| 34 |
|
| 35 |
class ModelInference:
|
| 36 |
-
"""模型推理类"""
|
| 37 |
-
|
| 38 |
def __init__(
|
| 39 |
self,
|
| 40 |
checkpoint_path: str,
|
|
@@ -44,8 +32,6 @@ class ModelInference:
|
|
| 44 |
):
|
| 45 |
self.device = torch.device(device)
|
| 46 |
print(f"Using device: {self.device}")
|
| 47 |
-
|
| 48 |
-
# 1. 加载 Tokenizer (与预训练一致)
|
| 49 |
print(f"Loading tokenizer: {tokenizer_name}...")
|
| 50 |
try:
|
| 51 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
@@ -59,26 +45,24 @@ class ModelInference:
|
|
| 59 |
except Exception as e:
|
| 60 |
print(f"Error loading tokenizer: {e}")
|
| 61 |
raise e
|
| 62 |
-
|
| 63 |
-
# 2. 配置模型参数 (必须与 pretrain.py 中的配置完全一致)
|
| 64 |
if config_path and Path(config_path).exists():
|
| 65 |
with open(config_path, 'r') as f:
|
| 66 |
self.config = json.load(f)
|
| 67 |
else:
|
| 68 |
-
# [CRITICAL] 这里使用了你在 pretrain.py 中使用的参数
|
| 69 |
self.config = {
|
| 70 |
-
'model_dim': 1536,
|
| 71 |
-
'vocab_size': len(self.tokenizer),
|
| 72 |
-
'n_layers': 12,
|
| 73 |
-
'n_heads': 12,
|
| 74 |
-
'n_kv_heads': 4,
|
| 75 |
-
'head_dim': None,
|
| 76 |
-
'max_seq_len': 512,
|
| 77 |
-
'dropout': 0.0,
|
| 78 |
-
'use_moe': False,
|
| 79 |
-
'use_adapter': False,
|
| 80 |
-
'use_lora': False,
|
| 81 |
-
'rope_scaling_type': "yarn"
|
| 82 |
}
|
| 83 |
|
| 84 |
# 3. 初始化模型结构
|
|
@@ -91,21 +75,18 @@ class ModelInference:
|
|
| 91 |
|
| 92 |
# 4. 加载权重
|
| 93 |
print(f"Loading checkpoint from {checkpoint_path}...")
|
| 94 |
-
# weights_only=False 是为了支持加载完整的 checkpoint 字典
|
| 95 |
checkpoint = torch.load(
|
| 96 |
checkpoint_path,
|
| 97 |
map_location=self.device,
|
| 98 |
weights_only=False
|
| 99 |
)
|
| 100 |
|
| 101 |
-
# 提取 state_dict
|
| 102 |
if 'model_state_dict' in checkpoint:
|
| 103 |
print("Found 'model_state_dict' in checkpoint.")
|
| 104 |
state_dict = checkpoint['model_state_dict']
|
| 105 |
else:
|
| 106 |
state_dict = checkpoint
|
| 107 |
|
| 108 |
-
# 处理可能的键名不匹配 (如 DDP 训练产生的 'module.' 前缀)
|
| 109 |
new_state_dict = {}
|
| 110 |
for k, v in state_dict.items():
|
| 111 |
if k.startswith('module.'):
|
|
@@ -113,7 +94,6 @@ class ModelInference:
|
|
| 113 |
else:
|
| 114 |
new_state_dict[k] = v
|
| 115 |
|
| 116 |
-
# 加载权重 (strict=False 允许忽略一些非关键的不匹配,如 loss 缓存等)
|
| 117 |
missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
|
| 118 |
if missing:
|
| 119 |
print(f"Warning: Missing keys: {len(missing)}")
|
|
@@ -143,40 +123,29 @@ class ModelInference:
|
|
| 143 |
image: Optional[Image.Image] = None
|
| 144 |
) -> str:
|
| 145 |
"""生成文本"""
|
| 146 |
-
|
| 147 |
-
# 编码输入
|
| 148 |
inputs = self.tokenizer(prompt, return_tensors="pt")
|
| 149 |
input_ids = inputs['input_ids'].to(self.device)
|
| 150 |
-
|
| 151 |
-
# 构建 MultiModalDenseTransformer 需要的输入格式
|
| 152 |
input_data = {'segments': []}
|
| 153 |
|
| 154 |
# 处理图像
|
| 155 |
if image is not None:
|
| 156 |
if image.mode != 'RGB':
|
| 157 |
image = image.convert('RGB')
|
| 158 |
-
# 简单的图像处理
|
| 159 |
image_tensor = image_transform(image).unsqueeze(0).to(self.device)
|
| 160 |
-
# 这里假设预处理器能处理这种输入
|
| 161 |
try:
|
| 162 |
-
# process_batch 接受 (batch_data, modality_type) 并返回 segments 列表
|
| 163 |
mod_segments = self.preprocessor.process_batch(image_tensor, 'image')
|
| 164 |
-
# 将返回的 segment 列表合并到 input_data
|
| 165 |
for seg in mod_segments:
|
| 166 |
input_data['segments'].append(seg)
|
| 167 |
except Exception as e:
|
| 168 |
print(f"Warning: Image processing skipped due to error: {e}")
|
| 169 |
-
|
| 170 |
-
# 添加文本段
|
| 171 |
input_data['segments'].append({
|
| 172 |
'type': 'text',
|
| 173 |
'data': input_ids,
|
| 174 |
'modality_id': 0
|
| 175 |
})
|
| 176 |
|
| 177 |
-
# 生成
|
| 178 |
try:
|
| 179 |
-
# 使用模型自带的 generate 方法
|
| 180 |
generated_ids = self.model.generate(
|
| 181 |
input_data,
|
| 182 |
max_new_tokens=max_new_tokens,
|
|
@@ -188,19 +157,11 @@ class ModelInference:
|
|
| 188 |
eos_token_id=self.tokenizer.eos_token_id,
|
| 189 |
pad_token_id=self.tokenizer.pad_token_id
|
| 190 |
)
|
| 191 |
-
|
| 192 |
-
# 解码
|
| 193 |
-
# 注意:生成的 ids 可能包含原始输入,或者只包含新生成的 token
|
| 194 |
-
# MultiModalDenseTransformer.generate 通常返回完整的序列
|
| 195 |
generated_text = self.tokenizer.decode(
|
| 196 |
generated_ids[0],
|
| 197 |
skip_special_tokens=True
|
| 198 |
)
|
| 199 |
-
|
| 200 |
-
# 如果包含 prompt,可以选择移除它只显示新内容
|
| 201 |
-
# if generated_text.startswith(prompt):
|
| 202 |
-
# generated_text = generated_text[len(prompt):]
|
| 203 |
-
|
| 204 |
return generated_text
|
| 205 |
|
| 206 |
except Exception as e:
|
|
@@ -209,8 +170,6 @@ class ModelInference:
|
|
| 209 |
traceback.print_exc()
|
| 210 |
return f"Error: {str(e)}"
|
| 211 |
|
| 212 |
-
|
| 213 |
-
# 全局模型实例
|
| 214 |
model_instance = None
|
| 215 |
app = Flask(__name__)
|
| 216 |
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
|
|
@@ -274,7 +233,7 @@ def create_html_template():
|
|
| 274 |
</head>
|
| 275 |
<body>
|
| 276 |
<div class="container">
|
| 277 |
-
<h1
|
| 278 |
|
| 279 |
<div>
|
| 280 |
<label><strong>提示词 (Prompt):</strong></label>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import torch
|
| 3 |
import torch.nn.functional as F
|
|
|
|
| 10 |
from pathlib import Path
|
| 11 |
from typing import Optional
|
| 12 |
|
|
|
|
| 13 |
from model import MultiModalDenseTransformer
|
|
|
|
|
|
|
| 14 |
from continual_learning import UnifiedMultiModalPreprocessor
|
|
|
|
| 15 |
from torchvision import transforms
|
| 16 |
|
|
|
|
| 17 |
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
|
| 18 |
|
|
|
|
| 19 |
image_transform = transforms.Compose([
|
| 20 |
transforms.Resize((224, 224)),
|
| 21 |
transforms.ToTensor(),
|
|
|
|
| 23 |
])
|
| 24 |
|
| 25 |
class ModelInference:
|
|
|
|
|
|
|
| 26 |
def __init__(
|
| 27 |
self,
|
| 28 |
checkpoint_path: str,
|
|
|
|
| 32 |
):
|
| 33 |
self.device = torch.device(device)
|
| 34 |
print(f"Using device: {self.device}")
|
|
|
|
|
|
|
| 35 |
print(f"Loading tokenizer: {tokenizer_name}...")
|
| 36 |
try:
|
| 37 |
self.tokenizer = AutoTokenizer.from_pretrained(
|
|
|
|
| 45 |
except Exception as e:
|
| 46 |
print(f"Error loading tokenizer: {e}")
|
| 47 |
raise e
|
| 48 |
+
|
|
|
|
| 49 |
if config_path and Path(config_path).exists():
|
| 50 |
with open(config_path, 'r') as f:
|
| 51 |
self.config = json.load(f)
|
| 52 |
else:
|
|
|
|
| 53 |
self.config = {
|
| 54 |
+
'model_dim': 1536,
|
| 55 |
+
'vocab_size': len(self.tokenizer),
|
| 56 |
+
'n_layers': 12,
|
| 57 |
+
'n_heads': 12,
|
| 58 |
+
'n_kv_heads': 4,
|
| 59 |
+
'head_dim': None,
|
| 60 |
+
'max_seq_len': 512,
|
| 61 |
+
'dropout': 0.0,
|
| 62 |
+
'use_moe': False,
|
| 63 |
+
'use_adapter': False,
|
| 64 |
+
'use_lora': False,
|
| 65 |
+
'rope_scaling_type': "yarn"
|
| 66 |
}
|
| 67 |
|
| 68 |
# 3. 初始化模型结构
|
|
|
|
| 75 |
|
| 76 |
# 4. 加载权重
|
| 77 |
print(f"Loading checkpoint from {checkpoint_path}...")
|
|
|
|
| 78 |
checkpoint = torch.load(
|
| 79 |
checkpoint_path,
|
| 80 |
map_location=self.device,
|
| 81 |
weights_only=False
|
| 82 |
)
|
| 83 |
|
|
|
|
| 84 |
if 'model_state_dict' in checkpoint:
|
| 85 |
print("Found 'model_state_dict' in checkpoint.")
|
| 86 |
state_dict = checkpoint['model_state_dict']
|
| 87 |
else:
|
| 88 |
state_dict = checkpoint
|
| 89 |
|
|
|
|
| 90 |
new_state_dict = {}
|
| 91 |
for k, v in state_dict.items():
|
| 92 |
if k.startswith('module.'):
|
|
|
|
| 94 |
else:
|
| 95 |
new_state_dict[k] = v
|
| 96 |
|
|
|
|
| 97 |
missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
|
| 98 |
if missing:
|
| 99 |
print(f"Warning: Missing keys: {len(missing)}")
|
|
|
|
| 123 |
image: Optional[Image.Image] = None
|
| 124 |
) -> str:
|
| 125 |
"""生成文本"""
|
|
|
|
|
|
|
| 126 |
inputs = self.tokenizer(prompt, return_tensors="pt")
|
| 127 |
input_ids = inputs['input_ids'].to(self.device)
|
|
|
|
|
|
|
| 128 |
input_data = {'segments': []}
|
| 129 |
|
| 130 |
# 处理图像
|
| 131 |
if image is not None:
|
| 132 |
if image.mode != 'RGB':
|
| 133 |
image = image.convert('RGB')
|
|
|
|
| 134 |
image_tensor = image_transform(image).unsqueeze(0).to(self.device)
|
|
|
|
| 135 |
try:
|
|
|
|
| 136 |
mod_segments = self.preprocessor.process_batch(image_tensor, 'image')
|
|
|
|
| 137 |
for seg in mod_segments:
|
| 138 |
input_data['segments'].append(seg)
|
| 139 |
except Exception as e:
|
| 140 |
print(f"Warning: Image processing skipped due to error: {e}")
|
| 141 |
+
|
|
|
|
| 142 |
input_data['segments'].append({
|
| 143 |
'type': 'text',
|
| 144 |
'data': input_ids,
|
| 145 |
'modality_id': 0
|
| 146 |
})
|
| 147 |
|
|
|
|
| 148 |
try:
|
|
|
|
| 149 |
generated_ids = self.model.generate(
|
| 150 |
input_data,
|
| 151 |
max_new_tokens=max_new_tokens,
|
|
|
|
| 157 |
eos_token_id=self.tokenizer.eos_token_id,
|
| 158 |
pad_token_id=self.tokenizer.pad_token_id
|
| 159 |
)
|
| 160 |
+
|
|
|
|
|
|
|
|
|
|
| 161 |
generated_text = self.tokenizer.decode(
|
| 162 |
generated_ids[0],
|
| 163 |
skip_special_tokens=True
|
| 164 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
return generated_text
|
| 166 |
|
| 167 |
except Exception as e:
|
|
|
|
| 170 |
traceback.print_exc()
|
| 171 |
return f"Error: {str(e)}"
|
| 172 |
|
|
|
|
|
|
|
| 173 |
model_instance = None
|
| 174 |
app = Flask(__name__)
|
| 175 |
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
|
|
|
|
| 233 |
</head>
|
| 234 |
<body>
|
| 235 |
<div class="container">
|
| 236 |
+
<h1> 模型在线推理</h1>
|
| 237 |
|
| 238 |
<div>
|
| 239 |
<label><strong>提示词 (Prompt):</strong></label>
|