|
|
import os |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from flask import Flask, render_template, request, jsonify |
|
|
from transformers import AutoTokenizer |
|
|
from PIL import Image |
|
|
import json |
|
|
import io |
|
|
import base64 |
|
|
from pathlib import Path |
|
|
from typing import Optional |
|
|
|
|
|
from model import MultiModalDenseTransformer |
|
|
from continual_learning import UnifiedMultiModalPreprocessor |
|
|
from torchvision import transforms |
|
|
|
|
|
# Route Hugging Face Hub downloads through a mirror (useful behind
# restricted networks); must be set before the tokenizer is fetched.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

# Preprocessing applied to every uploaded image before it reaches the model:
# resize to 224x224, convert to a CHW float tensor in [0, 1], then normalize
# with the standard ImageNet mean/std.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
|
|
|
|
class ModelInference:
    """Wraps a checkpointed MultiModalDenseTransformer for serving.

    Loads the tokenizer, model configuration, model weights, and the
    multimodal preprocessor once at construction time, then exposes
    :meth:`generate_text` for (optionally image-conditioned) generation.
    """

    def __init__(
        self,
        checkpoint_path: str,
        tokenizer_name: str,
        config_path: Optional[str] = None,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ):
        """Build the model and load weights.

        Args:
            checkpoint_path: Path to a ``.pt`` checkpoint (either a raw
                state dict or a dict containing ``'model_state_dict'``).
            tokenizer_name: Hugging Face tokenizer repo id or local path.
            config_path: Optional JSON file with model hyperparameters;
                when absent, a built-in default config is used.
            device: Target device string; defaults to CUDA when available.
                NOTE: the default is evaluated once at import time.

        Raises:
            Exception: re-raises any tokenizer or model loading failure.
        """
        self.device = torch.device(device)
        print(f"Using device: {self.device}")
        print(f"Loading tokenizer: {tokenizer_name}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                use_fast=True,
                trust_remote_code=True
            )
            # Some tokenizers (e.g. GPT-style) ship without a pad token;
            # reuse EOS so padding is always well-defined.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        except Exception as e:
            print(f"Error loading tokenizer: {e}")
            raise e

        # Model hyperparameters: external JSON takes precedence, otherwise
        # fall back to the hard-coded defaults below.
        if config_path and Path(config_path).exists():
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            self.config = {
                'model_dim': 1536,
                'vocab_size': len(self.tokenizer),
                'n_layers': 12,
                'n_heads': 12,
                'n_kv_heads': 4,
                'head_dim': None,
                'max_seq_len': 512,
                'dropout': 0.0,
                'use_moe': False,
                'use_adapter': False,
                'use_lora': False,
                'rope_scaling_type': "yarn"
            }

        print("Initializing model architecture...")
        try:
            self.model = MultiModalDenseTransformer(**self.config)
            self.preprocessor = UnifiedMultiModalPreprocessor(
                model_dim=self.config['model_dim']
            )

            print(f"Loading checkpoint from {checkpoint_path}...")
            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # only load checkpoints from trusted sources.
            checkpoint = torch.load(
                checkpoint_path,
                map_location=self.device,
                weights_only=False
            )

            # Checkpoints may be a full training dict or a bare state dict.
            if 'model_state_dict' in checkpoint:
                print("Found 'model_state_dict' in checkpoint.")
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint

            # Strip the 'module.' prefix added by DataParallel/DDP wrappers.
            new_state_dict = {}
            for k, v in state_dict.items():
                if k.startswith('module.'):
                    new_state_dict[k[7:]] = v
                else:
                    new_state_dict[k] = v

            # strict=False tolerates architecture drift between training
            # and serving; mismatches are surfaced as warnings below.
            missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
            if missing:
                print(f"Warning: Missing keys: {len(missing)}")
            if unexpected:
                print(f"Warning: Unexpected keys: {len(unexpected)}")

            self.model.to(self.device)
            self.preprocessor.to(self.device)
            self.model.eval()

            print("Model loaded successfully!")
            print(f"Total parameters: {sum(p.numel() for p in self.model.parameters()) / 1e6:.2f}M")

        except Exception as e:
            print(f"Error initializing model: {e}")
            raise e

    @torch.no_grad()
    def generate_text(
        self,
        prompt: str,
        max_new_tokens: int = 128,
        temperature: float = 0.7,
        top_k: int = 40,
        top_p: float = 0.9,
        repetition_penalty: float = 1.1,
        image: Optional[Image.Image] = None
    ) -> str:
        """Generate text from a prompt, optionally conditioned on an image.

        Args:
            prompt: Input text fed to the tokenizer.
            max_new_tokens: Maximum number of tokens to sample.
            temperature: Softmax temperature for sampling.
            top_k: Top-k sampling cutoff.
            top_p: Nucleus sampling cutoff.
            repetition_penalty: Penalty applied to repeated tokens.
            image: Optional PIL image; converted to RGB and encoded into
                extra input segments by the preprocessor.

        Returns:
            The decoded generation, or an ``"Error: ..."`` string if the
            model's generate call fails (errors are printed, not raised).
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs['input_ids'].to(self.device)
        # Multimodal input format expected by model.generate: a list of
        # typed segments, image segments (if any) preceding the text.
        input_data = {'segments': []}

        if image is not None:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image_tensor = image_transform(image).unsqueeze(0).to(self.device)
            try:
                # Image failures degrade to text-only generation rather
                # than failing the whole request.
                mod_segments = self.preprocessor.process_batch(image_tensor, 'image')
                for seg in mod_segments:
                    input_data['segments'].append(seg)
            except Exception as e:
                print(f"Warning: Image processing skipped due to error: {e}")

        # modality_id 0 presumably denotes text — confirm against the
        # preprocessor/model convention.
        input_data['segments'].append({
            'type': 'text',
            'data': input_ids,
            'modality_id': 0
        })

        try:
            generated_ids = self.model.generate(
                input_data,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id
            )

            generated_text = self.tokenizer.decode(
                generated_ids[0],
                skip_special_tokens=True
            )
            return generated_text

        except Exception as e:
            print(f"Generation error: {e}")
            import traceback
            traceback.print_exc()
            return f"Error: {str(e)}"
|
|
|
|
|
# Global inference engine; populated by main() before the server starts.
model_instance = None
app = Flask(__name__)
# Cap request bodies at 16 MiB (covers base64-encoded image uploads).
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
|
|
|
|
|
@app.route('/')
def index():
    """Render the main page, exposing the loaded model's config for display."""
    if model_instance:
        shown_config = dict(model_instance.config)
    else:
        shown_config = {}
    return render_template('index.html', config=shown_config)
|
|
|
|
|
@app.route('/generate', methods=['POST'])
def generate():
    """Handle a JSON generation request.

    Expects ``{"prompt": str, "max_tokens"?, "temperature"?, "top_k"?,
    "top_p"?, "repetition_penalty"?, "image"?}`` where ``image`` is either a
    ``data:`` URL or a bare base64 string.

    Returns:
        200 with ``{"output": str}`` on success, 400 for an empty prompt,
        503 when the model is not yet loaded, 500 on unexpected errors.
    """
    # Guard: the server can be imported/run before main() loads the model;
    # without this check we'd raise AttributeError and return an opaque 500.
    if model_instance is None:
        return jsonify({'error': '模型尚未加载'}), 503
    try:
        # silent=True avoids Flask raising on a missing/incorrect
        # Content-Type; we validate the payload ourselves below.
        data = request.get_json(silent=True) or {}
        prompt = data.get('prompt', '')
        if not prompt.strip():
            return jsonify({'error': '请输入提示文本'}), 400

        max_tokens = int(data.get('max_tokens', 100))
        temperature = float(data.get('temperature', 0.7))
        top_k = int(data.get('top_k', 40))
        top_p = float(data.get('top_p', 0.9))
        repetition_penalty = float(data.get('repetition_penalty', 1.1))

        image = None
        if data.get('image'):
            try:
                # split(',')[-1] accepts both "data:image/...;base64,XXXX"
                # URLs and bare base64 payloads (the old [1] index raised
                # IndexError on the latter).
                payload = data['image'].split(',')[-1]
                image = Image.open(io.BytesIO(base64.b64decode(payload)))
            except Exception as e:
                # Best-effort: a bad image degrades to text-only generation.
                print(f"Image load error: {e}")

        # Keyword arguments keep this call robust against signature changes
        # in generate_text (max_tokens maps onto max_new_tokens).
        output = model_instance.generate_text(
            prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            image=image,
        )
        return jsonify({'output': output})

    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
|
|
|
|
def create_html_template():
    """Write the single-page UI to ``templates/index.html``.

    Creates the ``templates`` directory (relative to the current working
    directory, as Flask expects) if needed, then overwrites the template.
    """
    html_content = '''
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Model Inference</title>
    <style>
        body { font-family: sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; background: #f0f2f5; }
        .container { background: white; padding: 30px; border-radius: 12px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); }
        h1 { color: #1a73e8; text-align: center; }
        textarea { width: 100%; padding: 10px; border: 1px solid #ddd; border-radius: 8px; margin: 10px 0; min-height: 100px; }
        .controls { display: grid; grid-template-columns: 1fr 1fr; gap: 15px; margin: 20px 0; background: #f8f9fa; padding: 15px; border-radius: 8px; }
        button { background: #1a73e8; color: white; border: none; padding: 12px 24px; border-radius: 6px; cursor: pointer; width: 100%; font-size: 16px; transition: background 0.3s; }
        button:hover { background: #1557b0; }
        button:disabled { background: #ccc; }
        #output { margin-top: 20px; padding: 20px; background: #f8f9fa; border-radius: 8px; white-space: pre-wrap; min-height: 100px; border: 1px solid #e0e0e0; }
        .loading { color: #666; font-style: italic; }
    </style>
</head>
<body>
    <div class="container">
        <h1> 模型在线推理</h1>

        <div>
            <label><strong>提示词 (Prompt):</strong></label>
            <textarea id="prompt" placeholder="请输入你的问题..."></textarea>
        </div>

        <div class="controls">
            <div>
                <label>Max Tokens: <span id="maxTokensVal">128</span></label>
                <input type="range" id="maxTokens" min="32" max="1024" value="128" style="width:100%" oninput="document.getElementById('maxTokensVal').innerText=this.value">
            </div>
            <div>
                <label>Temperature: <span id="tempVal">0.7</span></label>
                <input type="range" id="temperature" min="0.1" max="1.5" step="0.1" value="0.7" style="width:100%" oninput="document.getElementById('tempVal').innerText=this.value">
            </div>
        </div>

        <button id="btn" onclick="generate()">生成 (Generate)</button>

        <div id="output">结果将显示在这里...</div>
    </div>

    <script>
        async function generate() {
            const prompt = document.getElementById('prompt').value;
            if(!prompt) return alert("请输入内容");

            const btn = document.getElementById('btn');
            const out = document.getElementById('output');

            btn.disabled = true;
            btn.innerText = "生成中...";
            out.innerHTML = '<div class="loading">正在思考中...</div>';

            try {
                const res = await fetch('/generate', {
                    method: 'POST',
                    headers: {'Content-Type': 'application/json'},
                    body: JSON.stringify({
                        prompt: prompt,
                        max_tokens: parseInt(document.getElementById('maxTokens').value),
                        temperature: parseFloat(document.getElementById('temperature').value)
                    })
                });
                const data = await res.json();
                if(data.error) out.innerText = "Error: " + data.error;
                else out.innerText = data.output;
            } catch(e) {
                out.innerText = "请求失败: " + e;
            } finally {
                btn.disabled = false;
                btn.innerText = "生成 (Generate)";
            }
        }
    </script>
</body>
</html>
'''

    # pathlib idiom: mkdir + write_text replaces the manual open/write pair
    # and guarantees the file handle is closed.
    template_dir = Path('templates')
    template_dir.mkdir(exist_ok=True)
    (template_dir / 'index.html').write_text(html_content, encoding='utf-8')
|
|
|
|
|
def main():
    """CLI entry point: resolve a checkpoint, load the model, start Flask."""
    import argparse
    parser = argparse.ArgumentParser()

    parser.add_argument("--checkpoint", type=str, default="/root/multimodal/checkpoints/pretrain_fixed/step_10000.pt")
    parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--port", type=int, default=5001)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    args = parser.parse_args()

    if not Path(args.checkpoint).exists():
        # Fall back to the newest step_*.pt NEXT TO the requested checkpoint
        # (the old code scanned a hard-coded "checkpoints/pretrain" dir that
        # did not match the default path's "pretrain_fixed").
        def _step_num(p: Path) -> int:
            # Numeric sort: glob order is arbitrary and even a lexicographic
            # sort would rank step_9000 after step_10000.
            try:
                return int(p.stem.split('_')[-1])
            except ValueError:
                return -1

        steps = sorted(Path(args.checkpoint).parent.glob("step_*.pt"), key=_step_num)
        if steps:
            print(f"未找到 {args.checkpoint},尝试使用最新的 checkpoint: {steps[-1]}")
            args.checkpoint = str(steps[-1])
        else:
            print(f"错误: 找不到检查点文件: {args.checkpoint}")
            return

    create_html_template()

    global model_instance
    model_instance = ModelInference(args.checkpoint, args.tokenizer)

    print(f"\n服务已启动: http://{args.host}:{args.port}")
    # NOTE(review): debug=True exposes the Werkzeug debugger — keep it off
    # on any non-development deployment.
    app.run(host=args.host, port=args.port,
            debug=True,
            use_reloader=False)
|
|
|
|
|
# Script entry point: only start the server when run directly.
if __name__ == "__main__":
    main()