| import os |
| import torch |
| import torch.nn.functional as F |
| from flask import Flask, render_template, request, jsonify |
| from transformers import AutoTokenizer |
| from PIL import Image |
| import json |
| import io |
| import base64 |
| from pathlib import Path |
| from typing import Optional |
|
|
| from model import MultiModalDenseTransformer |
| from continual_learning import UnifiedMultiModalPreprocessor |
| from torchvision import transforms |
|
|
# Route HuggingFace Hub downloads through a mirror (useful behind firewalls).
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"


# Image preprocessing pipeline: resize to 224x224, convert to a CHW float
# tensor, and normalize with the standard ImageNet channel mean/std.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
|
|
class ModelInference:
    """Wraps a MultiModalDenseTransformer checkpoint for interactive inference.

    Loads the tokenizer, resolves a model config, restores the checkpoint
    (handling both raw state dicts and ``model_state_dict`` wrappers, and
    stripping DataParallel/DDP ``module.`` prefixes), then exposes
    :meth:`generate_text` for optionally image-conditioned generation.
    """

    def __init__(
        self,
        checkpoint_path: str,
        tokenizer_name: str,
        config_path: Optional[str] = None,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ):
        # NOTE(review): the default for `device` is evaluated once at import
        # time; callers can still override it explicitly per instance.
        self.device = torch.device(device)
        print(f"Using device: {self.device}")
        print(f"Loading tokenizer: {tokenizer_name}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                use_fast=True,
                trust_remote_code=True
            )
            # Some tokenizers ship without a pad token; reuse EOS so that
            # padding during generation stays valid.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        except Exception as e:
            print(f"Error loading tokenizer: {e}")
            raise e

        # Model hyperparameters: prefer an on-disk JSON config; otherwise
        # fall back to hard-coded defaults that must match the checkpoint.
        if config_path and Path(config_path).exists():
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            self.config = {
                'model_dim': 1536,
                'vocab_size': len(self.tokenizer),
                'n_layers': 12,
                'n_heads': 12,
                'n_kv_heads': 4,
                'head_dim': None,
                'max_seq_len': 512,
                'dropout': 0.0,
                'use_moe': False,
                'use_adapter': False,
                'use_lora': False,
                'rope_scaling_type': "yarn"
            }

        print("Initializing model architecture...")
        try:
            self.model = MultiModalDenseTransformer(**self.config)
            self.preprocessor = UnifiedMultiModalPreprocessor(
                model_dim=self.config['model_dim']
            )

            print(f"Loading checkpoint from {checkpoint_path}...")
            # NOTE(review): weights_only=False unpickles arbitrary objects —
            # only load checkpoints from trusted sources.
            checkpoint = torch.load(
                checkpoint_path,
                map_location=self.device,
                weights_only=False
            )

            # Training scripts may save either a raw state dict or a dict
            # with metadata under 'model_state_dict'.
            if 'model_state_dict' in checkpoint:
                print("Found 'model_state_dict' in checkpoint.")
                state_dict = checkpoint['model_state_dict']
            else:
                state_dict = checkpoint

            # Strip the 'module.' prefix added by DataParallel/DDP wrappers.
            new_state_dict = {}
            for k, v in state_dict.items():
                if k.startswith('module.'):
                    new_state_dict[k[7:]] = v
                else:
                    new_state_dict[k] = v

            # strict=False tolerates key mismatches; log the counts so a
            # silently-partial load is at least visible.
            missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
            if missing:
                print(f"Warning: Missing keys: {len(missing)}")
            if unexpected:
                print(f"Warning: Unexpected keys: {len(unexpected)}")

            self.model.to(self.device)
            self.preprocessor.to(self.device)
            self.model.eval()

            print("Model loaded successfully!")
            print(f"Total parameters: {sum(p.numel() for p in self.model.parameters()) / 1e6:.2f}M")

        except Exception as e:
            print(f"Error initializing model: {e}")
            raise e

    @torch.no_grad()
    def generate_text(
        self,
        prompt: str,
        max_new_tokens: int = 128,
        temperature: float = 0.7,
        top_k: int = 40,
        top_p: float = 0.9,
        repetition_penalty: float = 1.1,
        image: Optional[Image.Image] = None
    ) -> str:
        """Generate text from a prompt, optionally conditioned on an image.

        Args:
            prompt: Input text fed to the tokenizer.
            max_new_tokens: Maximum number of tokens to sample.
            temperature: Sampling temperature.
            top_k: Top-k sampling cutoff.
            top_p: Nucleus sampling cutoff.
            repetition_penalty: Penalty applied to already-generated tokens.
            image: Optional PIL image; converted to RGB and prepended as an
                image segment when provided.

        Returns:
            The decoded generation, or an ``"Error: ..."`` string on failure
            (failures are reported rather than raised so the web UI can
            display them).
        """
        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs['input_ids'].to(self.device)
        # The model consumes a list of modality segments; image segments
        # (if any) precede the text segment.
        input_data = {'segments': []}

        if image is not None:
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image_tensor = image_transform(image).unsqueeze(0).to(self.device)
            try:
                mod_segments = self.preprocessor.process_batch(image_tensor, 'image')
                for seg in mod_segments:
                    input_data['segments'].append(seg)
            except Exception as e:
                # Best-effort: fall back to text-only generation.
                print(f"Warning: Image processing skipped due to error: {e}")

        input_data['segments'].append({
            'type': 'text',
            'data': input_ids,
            # NOTE(review): presumably modality_id 0 denotes text — confirm
            # against model.py.
            'modality_id': 0
        })

        try:
            generated_ids = self.model.generate(
                input_data,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id
            )

            generated_text = self.tokenizer.decode(
                generated_ids[0],
                skip_special_tokens=True
            )
            return generated_text

        except Exception as e:
            print(f"Generation error: {e}")
            import traceback
            traceback.print_exc()
            return f"Error: {str(e)}"
|
|
# Global inference engine; populated by main() before the server starts.
model_instance = None

app = Flask(__name__)
# Reject request bodies larger than 16 MiB (covers base64-encoded uploads).
app.config['MAX_CONTENT_LENGTH'] = 16 << 20
|
|
@app.route('/')
def index():
    """Render the landing page with the loaded model's config (empty if no model)."""
    cfg = dict(model_instance.config) if model_instance is not None else {}
    return render_template('index.html', config=cfg)
|
|
@app.route('/generate', methods=['POST'])
def generate():
    """Handle a generation request.

    Expects a JSON body with 'prompt' plus optional sampling parameters and
    an optional base64 data-URL 'image'. Returns {'output': ...} on success,
    or {'error': ...} with status 400 (bad input) / 500 (internal failure).
    """
    try:
        # silent=True returns None (instead of raising) for a missing or
        # malformed JSON body, so we can answer with a clean 400 rather
        # than the generic 500 an AttributeError would have produced.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': '无效的JSON请求'}), 400

        prompt = data.get('prompt', '')
        if not isinstance(prompt, str) or not prompt.strip():
            return jsonify({'error': '请输入提示文本'}), 400

        max_tokens = int(data.get('max_tokens', 100))
        temperature = float(data.get('temperature', 0.7))
        top_k = int(data.get('top_k', 40))
        top_p = float(data.get('top_p', 0.9))
        repetition_penalty = float(data.get('repetition_penalty', 1.1))

        image = None
        if data.get('image'):
            try:
                # Accept a data URL: "data:image/...;base64,<payload>".
                payload = data['image'].split(',', 1)[1]
                image = Image.open(io.BytesIO(base64.b64decode(payload)))
            except Exception as e:
                # Best-effort: fall back to text-only generation.
                print(f"Image load error: {e}")

        # Keyword arguments guard against silent positional mismatches with
        # ModelInference.generate_text's signature.
        output = model_instance.generate_text(
            prompt,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            image=image,
        )
        return jsonify({'output': output})

    except Exception as e:
        return jsonify({'error': str(e)}), 500
|
|
def create_html_template():
    """Write the single-page web UI to templates/index.html (UTF-8)."""
    html_content = '''
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Model Inference</title>
    <style>
        body { font-family: sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; background: #f0f2f5; }
        .container { background: white; padding: 30px; border-radius: 12px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); }
        h1 { color: #1a73e8; text-align: center; }
        textarea { width: 100%; padding: 10px; border: 1px solid #ddd; border-radius: 8px; margin: 10px 0; min-height: 100px; }
        .controls { display: grid; grid-template-columns: 1fr 1fr; gap: 15px; margin: 20px 0; background: #f8f9fa; padding: 15px; border-radius: 8px; }
        button { background: #1a73e8; color: white; border: none; padding: 12px 24px; border-radius: 6px; cursor: pointer; width: 100%; font-size: 16px; transition: background 0.3s; }
        button:hover { background: #1557b0; }
        button:disabled { background: #ccc; }
        #output { margin-top: 20px; padding: 20px; background: #f8f9fa; border-radius: 8px; white-space: pre-wrap; min-height: 100px; border: 1px solid #e0e0e0; }
        .loading { color: #666; font-style: italic; }
    </style>
</head>
<body>
    <div class="container">
        <h1> 模型在线推理</h1>

        <div>
            <label><strong>提示词 (Prompt):</strong></label>
            <textarea id="prompt" placeholder="请输入你的问题..."></textarea>
        </div>

        <div class="controls">
            <div>
                <label>Max Tokens: <span id="maxTokensVal">128</span></label>
                <input type="range" id="maxTokens" min="32" max="1024" value="128" style="width:100%" oninput="document.getElementById('maxTokensVal').innerText=this.value">
            </div>
            <div>
                <label>Temperature: <span id="tempVal">0.7</span></label>
                <input type="range" id="temperature" min="0.1" max="1.5" step="0.1" value="0.7" style="width:100%" oninput="document.getElementById('tempVal').innerText=this.value">
            </div>
        </div>

        <button id="btn" onclick="generate()">生成 (Generate)</button>

        <div id="output">结果将显示在这里...</div>
    </div>

    <script>
        async function generate() {
            const prompt = document.getElementById('prompt').value;
            if(!prompt) return alert("请输入内容");

            const btn = document.getElementById('btn');
            const out = document.getElementById('output');

            btn.disabled = true;
            btn.innerText = "生成中...";
            out.innerHTML = '<div class="loading">正在思考中...</div>';

            try {
                const res = await fetch('/generate', {
                    method: 'POST',
                    headers: {'Content-Type': 'application/json'},
                    body: JSON.stringify({
                        prompt: prompt,
                        max_tokens: parseInt(document.getElementById('maxTokens').value),
                        temperature: parseFloat(document.getElementById('temperature').value)
                    })
                });
                const data = await res.json();
                if(data.error) out.innerText = "Error: " + data.error;
                else out.innerText = data.output;
            } catch(e) {
                out.innerText = "请求失败: " + e;
            } finally {
                btn.disabled = false;
                btn.innerText = "生成 (Generate)";
            }
        }
    </script>
</body>
</html>
'''

    target = Path('templates')
    target.mkdir(exist_ok=True)
    (target / 'index.html').write_text(html_content, encoding='utf-8')
|
|
def main():
    """CLI entry point: parse args, locate a checkpoint, start the Flask app."""
    import argparse
    parser = argparse.ArgumentParser()

    parser.add_argument("--checkpoint", type=str, default="/root/multimodal/checkpoints/pretrain_fixed/step_10000.pt")
    parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--port", type=int, default=5001)
    parser.add_argument("--host", type=str, default="0.0.0.0")
    args = parser.parse_args()

    if not Path(args.checkpoint).exists():
        # Fall back to the newest step_*.pt checkpoint. Sort numerically by
        # step: raw glob order is arbitrary, and a lexicographic sort would
        # rank step_9000 after step_10000.
        def _step_num(p):
            try:
                return int(p.stem.split('_')[-1])
            except ValueError:
                return -1  # malformed names sort first, never chosen as "latest"

        steps = sorted(Path("checkpoints/pretrain").glob("step_*.pt"), key=_step_num)
        if steps:
            print(f"未找到 final_model.pt,尝试使用最新的 checkpoint: {steps[-1]}")
            args.checkpoint = str(steps[-1])
        else:
            print(f"错误: 找不到检查点文件: {args.checkpoint}")
            return

    create_html_template()

    global model_instance
    model_instance = ModelInference(args.checkpoint, args.tokenizer)

    print(f"\n服务已启动: http://{args.host}:{args.port}")
    # NOTE(review): debug=True enables the Werkzeug interactive debugger,
    # which allows arbitrary code execution if the port is exposed — disable
    # it for any non-local deployment.
    app.run(host=args.host, port=args.port,
            debug=True,
            use_reloader=False)
|
|
# Script entry point: build the HTML template, load the model, start the server.
if __name__ == "__main__":
    main()