# MultiModal / infer.py
# Author: szxllm — last commit "Update infer.py" (9c85325, verified)
import os
import torch
import torch.nn.functional as F
from flask import Flask, render_template, request, jsonify
from transformers import AutoTokenizer
from PIL import Image
import json
import io
import base64
from pathlib import Path
from typing import Optional
from model import MultiModalDenseTransformer
from continual_learning import UnifiedMultiModalPreprocessor
from torchvision import transforms
# Route Hugging Face downloads through the hf-mirror.com mirror, but respect an
# endpoint the user has already exported instead of overwriting it.
# NOTE(review): `transformers` (and therefore huggingface_hub) is imported above
# this line, and huggingface_hub typically reads HF_ENDPOINT when it is first
# imported, so this assignment may not affect tokenizer downloads. Exporting
# HF_ENDPOINT in the shell, or setting it before those imports, is the reliable
# way — confirm for the installed library versions.
os.environ.setdefault("HF_ENDPOINT", "https://hf-mirror.com")

# Transform applied to uploaded images before the multimodal preprocessor:
# resize to 224x224, convert the PIL image to a float tensor in [0, 1], then
# normalize each channel with the standard ImageNet mean/std.
# Presumably matches the training-time image pipeline — TODO confirm.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
class ModelInference:
    """Bundle the tokenizer, ``MultiModalDenseTransformer`` and preprocessor for inference.

    Construction does the following:
    - loads the tokenizer;
    - builds the model from a JSON config (or built-in defaults);
    - loads checkpoint weights into the model non-strictly;
    - moves everything to ``device``.

    ``generate_text`` then runs sampling-based generation on a text prompt,
    optionally with an image.
    """

    def __init__(
        self,
        checkpoint_path: str,
        tokenizer_name: str,
        config_path: Optional[str] = None,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ):
        """Load tokenizer, model, preprocessor and checkpoint weights.

        Args:
            checkpoint_path: File readable by ``torch.load``. It holds either a
                state dict or a dict with a ``'model_state_dict'`` entry.
            tokenizer_name: Name or path passed to ``AutoTokenizer.from_pretrained``.
            config_path: Optional JSON file with the model constructor kwargs.
                The built-in defaults are used when it is omitted or missing.
            device: Torch device string. The default is evaluated once, when the
                class is defined, from CUDA availability.

        Raises:
            Exception: Errors from tokenizer loading and from model construction
                or checkpoint loading are printed, then re-raised. Config-file
                errors propagate directly.
        """
        self.device = torch.device(device)
        print(f"Using device: {self.device}")

        # 1. Load the tokenizer.
        print(f"Loading tokenizer: {tokenizer_name}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                use_fast=True,
                trust_remote_code=True
            )
            # generate() is given a pad id; reuse EOS when the tokenizer has none.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        except Exception as e:
            print(f"Error loading tokenizer: {e}")
            raise

        # 2. Resolve the model configuration. The defaults need the tokenizer
        #    for vocab_size.
        self.config = self._load_config(config_path)

        # 3. Build the model architecture.
        print("Initializing model architecture...")
        try:
            self.model = MultiModalDenseTransformer(**self.config)
            self.preprocessor = UnifiedMultiModalPreprocessor(
                model_dim=self.config['model_dim']
            )
            # NOTE(review): only self.model receives checkpoint weights below. The
            # preprocessor keeps whatever parameters its constructor creates —
            # confirm this is intended.

            # 4. Load the weights.
            self._load_weights(checkpoint_path)

            self.model.to(self.device)
            self.preprocessor.to(self.device)
            self.model.eval()
            # Put the preprocessor in inference mode too, in case it contains
            # dropout or normalization layers.
            if isinstance(self.preprocessor, torch.nn.Module):
                self.preprocessor.eval()
            print("Model loaded successfully!")
            print(f"Total parameters: {sum(p.numel() for p in self.model.parameters()) / 1e6:.2f}M")
        except Exception as e:
            print(f"Error initializing model: {e}")
            raise

    def _load_config(self, config_path: Optional[str]) -> dict:
        """Return the kwargs for ``MultiModalDenseTransformer``.

        Reads *config_path* as UTF-8 JSON when it exists. Otherwise returns
        the built-in defaults, with ``vocab_size`` taken from the tokenizer.
        """
        if config_path and Path(config_path).exists():
            with open(config_path, 'r', encoding='utf-8') as f:
                return json.load(f)
        if config_path:
            # Falling back silently would hide a mistyped path.
            print(f"Warning: config file not found: {config_path}; using default config.")
        return {
            'model_dim': 1536,
            'vocab_size': len(self.tokenizer),
            'n_layers': 12,
            'n_heads': 12,
            'n_kv_heads': 4,
            'head_dim': None,
            'max_seq_len': 512,
            'dropout': 0.0,
            'use_moe': False,
            'use_adapter': False,
            'use_lora': False,
            'rope_scaling_type': "yarn"
        }

    def _load_weights(self, checkpoint_path: str) -> None:
        """Load checkpoint weights into ``self.model`` non-strictly and report key mismatches."""
        print(f"Loading checkpoint from {checkpoint_path}...")
        # Deserialize on the CPU. The weights are copied into self.model and the
        # model is moved to self.device afterwards, so mapping the checkpoint
        # straight to the GPU would only keep an extra copy of every weight there
        # while loading.
        # SECURITY: weights_only=False unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(
            checkpoint_path,
            map_location='cpu',
            weights_only=False
        )
        if 'model_state_dict' in checkpoint:
            print("Found 'model_state_dict' in checkpoint.")
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        missing, unexpected = self.model.load_state_dict(
            self._strip_module_prefix(state_dict), strict=False
        )
        if missing:
            print(f"Warning: Missing keys: {len(missing)}")
        if unexpected:
            print(f"Warning: Unexpected keys: {len(unexpected)}")

    @staticmethod
    def _strip_module_prefix(state_dict: dict) -> dict:
        """Return a copy of *state_dict* with a leading ``module.`` removed from each key.

        That prefix typically appears when a model is saved while wrapped in
        DataParallel/DistributedDataParallel.
        """
        prefix = 'module.'
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }

    @torch.no_grad()
    def generate_text(
        self,
        prompt: str,
        max_new_tokens: int = 128,
        temperature: float = 0.7,
        top_k: int = 40,
        top_p: float = 0.9,
        repetition_penalty: float = 1.1,
        image: Optional[Image.Image] = None
    ) -> str:
        """Generate text from *prompt*, optionally conditioned on *image*.

        Args:
            prompt: Text prompt, tokenized with ``self.tokenizer``.
            max_new_tokens, temperature, top_k, top_p, repetition_penalty:
                Passed through to ``self.model.generate``. Sampling is always
                enabled.
            image: Optional PIL image.
                - It is converted to RGB, run through the module-level
                  ``image_transform`` and then through the preprocessor.
                - Its segments are placed before the text segment.
                - Image handling is best-effort: on any failure a warning is
                  printed and generation continues on the text alone.

        Returns:
            The first sequence returned by ``self.model.generate``, decoded
            with special tokens removed. Whether it includes the prompt depends
            on ``generate``, which is not visible here. If generation raises,
            the traceback is printed and ``"Error: <message>"`` is returned
            instead of raising.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs['input_ids'].to(self.device)
        input_data = {'segments': []}

        if image is not None:
            # Conversion/transform errors (e.g. a truncated upload) are treated
            # like preprocessing errors, so a bad image never aborts the request.
            # Segments are only added once the whole image path has succeeded.
            try:
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                image_tensor = image_transform(image).unsqueeze(0).to(self.device)
                image_segments = list(self.preprocessor.process_batch(image_tensor, 'image'))
            except Exception as e:
                print(f"Warning: Image processing skipped due to error: {e}")
            else:
                input_data['segments'].extend(image_segments)

        input_data['segments'].append({
            'type': 'text',
            'data': input_ids,
            'modality_id': 0
        })
        try:
            generated_ids = self.model.generate(
                input_data,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id
            )
            return self.tokenizer.decode(
                generated_ids[0],
                skip_special_tokens=True
            )
        except Exception as e:
            print(f"Generation error: {e}")
            import traceback
            traceback.print_exc()
            return f"Error: {str(e)}"
app = Flask(__name__)
# Flask rejects request bodies larger than this (16 MiB). Base64-encoded
# images sent to /generate count toward the limit.
app.config['MAX_CONTENT_LENGTH'] = 16 << 20

# Shared ModelInference instance. It is None until main() loads the model,
# and the request handlers below read it.
model_instance = None
@app.route('/')
def index():
    """Serve the inference page.

    The template receives a shallow copy of the loaded model's config, or an
    empty dict when no model has been loaded yet.
    """
    if not model_instance:
        return render_template('index.html', config={})
    return render_template('index.html', config=model_instance.config.copy())
@app.route('/generate', methods=['POST'])
def generate():
    """Generate text for a JSON request and return ``{'output': <text>}``.

    Request body: a JSON object with these keys:
    - ``prompt`` (required, non-blank);
    - ``max_tokens`` (default 100), ``temperature`` (0.7), ``top_k`` (40),
      ``top_p`` (0.9) and ``repetition_penalty`` (1.1), all optional;
    - ``image`` (optional): a base64 string, either bare or as a
      ``data:...;base64,`` URL.

    Error responses carry ``{'error': <message>}``:
        400 - the body is not a JSON object, the prompt is missing/blank, or
              a sampling parameter cannot be converted to a number;
        503 - the model has not been loaded yet;
        500 - an exception escaped ``generate_text``.

    An image that cannot be decoded is logged and ignored, so generation
    falls back to text only.
    """
    if model_instance is None:
        return jsonify({'error': '模型尚未加载'}), 503

    # silent=True: a wrong content type or malformed JSON yields None (-> 400)
    # instead of an exception.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': '请求体必须是 JSON 对象'}), 400

    prompt = data.get('prompt', '')
    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({'error': '请输入提示文本'}), 400

    try:
        max_tokens = int(data.get('max_tokens', 100))
        temperature = float(data.get('temperature', 0.7))
        top_k = int(data.get('top_k', 40))
        top_p = float(data.get('top_p', 0.9))
        repetition_penalty = float(data.get('repetition_penalty', 1.1))
    except (TypeError, ValueError) as e:
        return jsonify({'error': f'参数格式无效: {e}'}), 400

    image = None
    encoded_image = data.get('image')
    if encoded_image:
        try:
            # Accept both "data:image/png;base64,XXXX" and bare "XXXX".
            image_bytes = base64.b64decode(encoded_image.split(',', 1)[-1])
            image = Image.open(io.BytesIO(image_bytes))
            # Image.open is lazy; decode now so corrupt data is caught here.
            image.load()
        except Exception as e:
            print(f"Image load error: {e}")
            image = None

    try:
        output = model_instance.generate_text(
            prompt, max_tokens, temperature, top_k, top_p, repetition_penalty, image
        )
    except Exception as e:
        return jsonify({'error': str(e)}), 500
    return jsonify({'output': output})
# Page served by the index route. It is written to disk by
# create_html_template() and rendered through Jinja.
_INDEX_HTML = '''
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Model Inference</title>
<style>
body { font-family: sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; background: #f0f2f5; }
.container { background: white; padding: 30px; border-radius: 12px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); }
h1 { color: #1a73e8; text-align: center; }
textarea { width: 100%; padding: 10px; border: 1px solid #ddd; border-radius: 8px; margin: 10px 0; min-height: 100px; }
.controls { display: grid; grid-template-columns: 1fr 1fr; gap: 15px; margin: 20px 0; background: #f8f9fa; padding: 15px; border-radius: 8px; }
button { background: #1a73e8; color: white; border: none; padding: 12px 24px; border-radius: 6px; cursor: pointer; width: 100%; font-size: 16px; transition: background 0.3s; }
button:hover { background: #1557b0; }
button:disabled { background: #ccc; }
#output { margin-top: 20px; padding: 20px; background: #f8f9fa; border-radius: 8px; white-space: pre-wrap; min-height: 100px; border: 1px solid #e0e0e0; }
.loading { color: #666; font-style: italic; }
</style>
</head>
<body>
<div class="container">
<h1> 模型在线推理</h1>
<div>
<label><strong>提示词 (Prompt):</strong></label>
<textarea id="prompt" placeholder="请输入你的问题..."></textarea>
</div>
<div class="controls">
<div>
<label>Max Tokens: <span id="maxTokensVal">128</span></label>
<input type="range" id="maxTokens" min="32" max="1024" value="128" style="width:100%" oninput="document.getElementById('maxTokensVal').innerText=this.value">
</div>
<div>
<label>Temperature: <span id="tempVal">0.7</span></label>
<input type="range" id="temperature" min="0.1" max="1.5" step="0.1" value="0.7" style="width:100%" oninput="document.getElementById('tempVal').innerText=this.value">
</div>
</div>
<button id="btn" onclick="generate()">生成 (Generate)</button>
<div id="output">结果将显示在这里...</div>
</div>
<script>
async function generate() {
const prompt = document.getElementById('prompt').value;
if(!prompt) return alert("请输入内容");
const btn = document.getElementById('btn');
const out = document.getElementById('output');
btn.disabled = true;
btn.innerText = "生成中...";
out.innerHTML = '<div class="loading">正在思考中...</div>';
try {
const res = await fetch('/generate', {
method: 'POST',
headers: {'Content-Type': 'application/json'},
body: JSON.stringify({
prompt: prompt,
max_tokens: parseInt(document.getElementById('maxTokens').value),
temperature: parseFloat(document.getElementById('temperature').value)
})
});
const data = await res.json();
if(data.error) out.innerText = "Error: " + data.error;
else out.innerText = data.output;
} catch(e) {
out.innerText = "请求失败: " + e;
} finally {
btn.disabled = false;
btn.innerText = "生成 (Generate)";
}
}
</script>
</body>
</html>
'''


def create_html_template(template_dir: Optional[str] = None) -> Path:
    """Write the inference page to ``<template_dir>/index.html``.

    Args:
        template_dir: Target directory (str or path-like). It is created,
            including parents, if missing. When omitted, the ``templates``
            directory next to this module is used. ``Flask(__name__)`` loads
            templates from the module's root path, not from the current
            working directory, so writing to a CWD-relative ``templates/``
            breaks ``render_template`` when the script is started from
            elsewhere.

    Returns:
        Path of the written file.
    """
    if template_dir is None:
        target_dir = Path(os.path.abspath(__file__)).parent / 'templates'
    else:
        target_dir = Path(template_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    target = target_dir / 'index.html'
    target.write_text(_INDEX_HTML, encoding='utf-8')
    return target
def main():
    """Command-line entry point: load the model and serve the web UI.

    Options:
        --checkpoint  Checkpoint file to load. If it does not exist, the
                      highest-numbered ``step_*.pt`` in ``checkpoints/pretrain``
                      is used instead. If there is none, the function returns
                      without starting the server.
        --tokenizer   Tokenizer name or path for ``AutoTokenizer``.
        --host/--port Address for the Flask development server.
        --debug       Enable Flask debug mode (off by default).
    """
    import argparse
    import re

    global model_instance

    parser = argparse.ArgumentParser()
    # Defaults to the checkpoint path saved by pretraining.
    parser.add_argument("--checkpoint", type=str, default="/root/multimodal/checkpoints/pretrain_fixed/step_10000.pt")
    parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--port", type=int, default=5001)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    # Debug mode enables the interactive Werkzeug debugger, which can execute
    # arbitrary code. The default host listens on all interfaces, so debug
    # mode must be an explicit opt-in.
    parser.add_argument("--debug", action="store_true",
                        help="enable Flask debug mode (never on a publicly reachable host)")
    args = parser.parse_args()

    if not Path(args.checkpoint).exists():
        def step_number(path: Path) -> int:
            # "step_10000.pt" -> 10000; names without a numeric step sort first.
            match = re.fullmatch(r"step_(\d+)", path.stem)
            return int(match.group(1)) if match else -1

        # glob() yields files in arbitrary order, and a lexical sort would rank
        # step_9000 above step_10000, so pick the latest by numeric step.
        steps = sorted(
            Path("checkpoints/pretrain").glob("step_*.pt"),
            key=lambda p: (step_number(p), p.name),
        )
        if steps:
            latest = steps[-1]
            print(f"未找到 {args.checkpoint},尝试使用最新的 checkpoint: {latest}")
            args.checkpoint = str(latest)
        else:
            print(f"错误: 找不到检查点文件: {args.checkpoint}")
            return

    create_html_template()
    model_instance = ModelInference(args.checkpoint, args.tokenizer)
    print(f"\n服务已启动: http://{args.host}:{args.port}")
    app.run(host=args.host, port=args.port,
            debug=args.debug,
            use_reloader=False)
if __name__ == "__main__":
    # Run as a script: parse CLI options, load the model and start serving.
    main()