import os
import io
import json
import base64
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
from flask import Flask, render_template, request, jsonify
from transformers import AutoTokenizer
from PIL import Image
from torchvision import transforms

from model import MultiModalDenseTransformer
from continual_learning import UnifiedMultiModalPreprocessor

os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

# Resize to 224x224 and normalize with the standard ImageNet channel statistics.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


class ModelInference:
    def __init__(
        self,
        checkpoint_path: str,
        tokenizer_name: str,
        config_path: Optional[str] = None,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    ):
        self.device = torch.device(device)
        print(f"Using device: {self.device}")

        # 1. Load the tokenizer
        print(f"Loading tokenizer: {tokenizer_name}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(
                tokenizer_name,
                use_fast=True,
                trust_remote_code=True
            )
            # Fall back to the EOS token when the tokenizer has no pad token.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
                self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        except Exception as e:
            print(f"Error loading tokenizer: {e}")
            raise e

        # 2. Load the model config from disk, or fall back to defaults
        if config_path and Path(config_path).exists():
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            self.config = {
                'model_dim': 1536,
                'vocab_size': len(self.tokenizer),
                'n_layers': 12,
                'n_heads': 12,
                'n_kv_heads': 4,
                'head_dim': None,
                'max_seq_len': 512,
                'dropout': 0.0,
                'use_moe': False,
                'use_adapter': False,
                'use_lora': False,
                'rope_scaling_type': "yarn"
            }

        # 3. Initialize the model architecture
        print("Initializing model architecture...")
        try:
            self.model = MultiModalDenseTransformer(**self.config)
            self.preprocessor = UnifiedMultiModalPreprocessor(
                model_dim=self.config['model_dim']
            )

            # 4. Load checkpoint weights
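            # The checkpoint may be either a raw state_dict or a dict that
            # wraps one under 'model_state_dict'; both forms are handled
            # below, and the 'module.' prefix that DataParallel/DDP saving
            # adds is stripped so the same file loads on a single device.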
加载权重 print(f"Loading checkpoint from {checkpoint_path}...") checkpoint = torch.load( checkpoint_path, map_location=self.device, weights_only=False ) if 'model_state_dict' in checkpoint: print("Found 'model_state_dict' in checkpoint.") state_dict = checkpoint['model_state_dict'] else: state_dict = checkpoint new_state_dict = {} for k, v in state_dict.items(): if k.startswith('module.'): new_state_dict[k[7:]] = v else: new_state_dict[k] = v missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False) if missing: print(f"Warning: Missing keys: {len(missing)}") if unexpected: print(f"Warning: Unexpected keys: {len(unexpected)}") self.model.to(self.device) self.preprocessor.to(self.device) self.model.eval() print("Model loaded successfully!") print(f"Total parameters: {sum(p.numel() for p in self.model.parameters()) / 1e6:.2f}M") except Exception as e: print(f"Error initializing model: {e}") raise e @torch.no_grad() def generate_text( self, prompt: str, max_new_tokens: int = 128, temperature: float = 0.7, top_k: int = 40, top_p: float = 0.9, repetition_penalty: float = 1.1, image: Optional[Image.Image] = None ) -> str: """生成文本""" inputs = self.tokenizer(prompt, return_tensors="pt") input_ids = inputs['input_ids'].to(self.device) input_data = {'segments': []} # 处理图像 if image is not None: if image.mode != 'RGB': image = image.convert('RGB') image_tensor = image_transform(image).unsqueeze(0).to(self.device) try: mod_segments = self.preprocessor.process_batch(image_tensor, 'image') for seg in mod_segments: input_data['segments'].append(seg) except Exception as e: print(f"Warning: Image processing skipped due to error: {e}") input_data['segments'].append({ 'type': 'text', 'data': input_ids, 'modality_id': 0 }) try: generated_ids = self.model.generate( input_data, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, do_sample=True, eos_token_id=self.tokenizer.eos_token_id, pad_token_id=self.tokenizer.pad_token_id ) generated_text = self.tokenizer.decode( generated_ids[0], skip_special_tokens=True ) return generated_text except Exception as e: print(f"Generation error: {e}") import traceback traceback.print_exc() return f"Error: {str(e)}" model_instance = None app = Flask(__name__) app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024 @app.route('/') def index(): display_config = model_instance.config.copy() if model_instance else {} return render_template('index.html', config=display_config) @app.route('/generate', methods=['POST']) def generate(): try: data = request.json prompt = data.get('prompt', '') if not prompt.strip(): return jsonify({'error': '请输入提示文本'}), 400 max_tokens = int(data.get('max_tokens', 100)) temperature = float(data.get('temperature', 0.7)) top_k = int(data.get('top_k', 40)) top_p = float(data.get('top_p', 0.9)) repetition_penalty = float(data.get('repetition_penalty', 1.1)) image = None if 'image' in data and data['image']: try: image_data = base64.b64decode(data['image'].split(',')[1]) image = Image.open(io.BytesIO(image_data)) except Exception as e: print(f"Image load error: {e}") output = model_instance.generate_text( prompt, max_tokens, temperature, top_k, top_p, repetition_penalty, image ) return jsonify({'output': output}) except Exception as e: return jsonify({'error': str(e)}), 500 def create_html_template(): """写入HTML模板""" html_content = '''