import os

# Set the Hugging Face mirror endpoint before transformers/huggingface_hub are
# imported, otherwise the setting may not be picked up.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

import argparse
import json
from pathlib import Path
from typing import Optional

import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoTokenizer

# UI
import gradio as gr

from model import MultiModalDenseTransformer
from continual_learning import UnifiedMultiModalPreprocessor

# Standard ImageNet preprocessing for image inputs.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


class ModelInference:
    def __init__(self, checkpoint_path: str, tokenizer_name: str, config_path: Optional[str] = None,
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = torch.device(device)
        print(f"Using device: {self.device}")

        print(f"Loading tokenizer: {tokenizer_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        # Load the model config from disk if provided, otherwise fall back to defaults.
        if config_path and Path(config_path).exists():
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            self.config = {
                'model_dim': 1536,
                'vocab_size': len(self.tokenizer),
                'n_layers': 12,
                'n_heads': 12,
                'n_kv_heads': 4,
                'head_dim': None,
                'max_seq_len': 512,
                'dropout': 0.0,
                'use_moe': False,
                'use_adapter': False,
                'use_lora': False,
                'rope_scaling_type': "yarn",
                'use_multimodal_fusion': False,
                'use_contrastive': False,
            }

        # Init model + preprocessor.
        print("Initializing model architecture...")
        self.model = MultiModalDenseTransformer(**self.config)
        self.preprocessor = UnifiedMultiModalPreprocessor(model_dim=self.config['model_dim'])

        print(f"Loading checkpoint from {checkpoint_path}...")
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        state_dict = checkpoint.get('model_state_dict', checkpoint) if isinstance(checkpoint, dict) else checkpoint

        # Strip the 'module.' prefix left behind by DistributedDataParallel checkpoints.
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith('module.'):
                new_state_dict[k[len('module.'):]] = v
            else:
                new_state_dict[k] = v

        missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
        if missing:
            print(f"Warning: Missing keys: {len(missing)}")
        if unexpected:
            print(f"Warning: Unexpected keys: {len(unexpected)}")

        self.model.to(self.device)
        self.preprocessor.to(self.device)
        self.model.eval()
        print("Model loaded successfully!")
        print(f"Total parameters: {sum(p.numel() for p in self.model.parameters()) / 1e6:.2f}M")

    @torch.no_grad()
    def generate_text(self, prompt: str, max_new_tokens: int = 128, temperature: float = 0.7,
                      top_k: int = 10, top_p: float = 0.9, repetition_penalty: float = 1.2,
                      image: Optional[Image.Image] = None) -> str:
        formatted_prompt = f"Instruction: {prompt}\nResponse:"
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
        input_ids = inputs['input_ids'].to(self.device)

        # Inputs are passed as a list of modality segments; the image segment
        # (if any) is placed before the text segment.
        input_data = {'segments': []}
        if image is not None:
            try:
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                image_tensor = image_transform(image).unsqueeze(0).to(self.device)
                mod_segments = self.preprocessor.process_batch(image_tensor, 'image')
                for seg in mod_segments:
                    input_data['segments'].append(seg)
            except Exception as e:
                print(f"Warning: Image processing skipped due to error: {e}")
        input_data['segments'].append({
            'type': 'text',
            'data': input_ids,
            'modality_id': 0,
        })

        try:
            generated_ids = self.model.generate(
                input_data,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            full_output = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

            # Keep only the text after "Response:" and truncate at stop words.
            if "Response:" in full_output:
                answer = full_output.split("Response:")[-1].strip()
            else:
                answer = full_output
            stop_words = ["Instruction", "Input", "###", "Response", "User:", "Assistant:", "\n\n"]
            for sw in stop_words:
                if sw in answer:
                    answer = answer.split(sw)[0].strip()

            # Drop the first line if it merely echoes the prompt.
            lines = answer.split('\n')
            if lines and prompt.lower() in lines[0].lower():
                answer = "\n".join(lines[1:]).strip()
            return answer
        except Exception as e:
            import traceback
            traceback.print_exc()
            return f"Error: {e}"


def build_ui(model_instance):
    with gr.Blocks(title="MultiModal Dense Transformer - Gradio", css="""
        .gradio-container { max-width: 900px; margin: auto; }
    """) as demo:
        gr.Markdown("## Multimodal online inference (text + image)")
        with gr.Row():
            with gr.Column(scale=3):
                txt = gr.Textbox(label="Prompt (Instruction)", placeholder="Enter an instruction or question...", lines=5)
                img = gr.Image(type="pil", label="(Optional) Upload an image (multimodal input)")
                btn = gr.Button("Generate")
            with gr.Column(scale=2):
                max_tokens = gr.Slider(label="Max New Tokens", minimum=16, maximum=1024, step=1, value=128)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, step=0.01, value=0.7)
                top_k = gr.Slider(label="Top-k", minimum=0, maximum=200, step=1, value=40)
                top_p = gr.Slider(label="Top-p", minimum=0.0, maximum=1.0, step=0.01, value=0.9)
                rep_pen = gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, step=0.01, value=1.1)
                status = gr.Textbox(label="Status", value="Ready", interactive=False)
        output = gr.Textbox(label="Output", lines=12, interactive=False)

        def gr_generate(prompt, image, max_tokens_v, temp_v, topk_v, topp_v, rep_v):
            if not prompt or str(prompt).strip() == "":
                return "", "Please enter a prompt"
            out = model_instance.generate_text(prompt=prompt,
                                               max_new_tokens=int(max_tokens_v),
                                               temperature=float(temp_v),
                                               top_k=int(topk_v),
                                               top_p=float(topp_v),
                                               repetition_penalty=float(rep_v),
                                               image=image)
            return out, "Done"

        btn.click(fn=gr_generate,
                  inputs=[txt, img, max_tokens, temperature, top_k, top_p, rep_pen],
                  outputs=[output, status])

    # Launching is left to the caller so that port/share can be configured via CLI flags.
    return demo


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default="/root/multimodal/checkpoints/posttrain/final_model.pt")
    parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--config", type=str, default=None)
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--share", type=lambda x: x.lower() in ("true", "1", "yes"), default=True)
    args = parser.parse_args()

    if not Path(args.checkpoint).exists():
        # Fall back to the most recently written pretraining checkpoint.
        possible = sorted(Path("checkpoints/pretrain").glob("step_*.pt"), key=lambda p: p.stat().st_mtime)
        if possible:
            args.checkpoint = str(possible[-1])
            print(f"final_model.pt not found, using latest checkpoint: {args.checkpoint}")
        else:
            raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint}")

    model_instance = ModelInference(args.checkpoint, args.tokenizer, args.config)
    demo = build_ui(model_instance)
    demo.launch(server_port=args.port, share=args.share)


if __name__ == "__main__":
    main()
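
# --- Usage sketch (comments only, not executed) -------------------------------
# A minimal example of driving this script; the script filename "app.py" and the
# checkpoint filename "step_1000.pt" are assumptions for illustration, while the
# flags and defaults come from the argument parser above.
#
#   python app.py \
#       --checkpoint /root/multimodal/checkpoints/posttrain/final_model.pt \
#       --tokenizer Qwen/Qwen2.5-7B-Instruct \
#       --port 7860 --share false
#
# ModelInference can also be used programmatically, bypassing the Gradio UI:
#
#   inf = ModelInference("checkpoints/pretrain/step_1000.pt", "Qwen/Qwen2.5-7B-Instruct")
#   print(inf.generate_text("Describe this picture.", image=Image.open("cat.jpg")))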