|
|
import os |
|
|
import argparse |
|
|
from pathlib import Path |
|
|
import json |
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
from PIL import Image |
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
|
|
|
import gradio as gr |
|
|
from model import MultiModalDenseTransformer |
|
|
from continual_learning import UnifiedMultiModalPreprocessor |
|
|
|
|
|
# Route Hugging Face Hub traffic through the hf-mirror endpoint (for
# networks where huggingface.co is unreachable). Must be set before any
# from_pretrained() download is attempted.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from torchvision import transforms

# Image preprocessing pipeline: resize to 224x224 (presumably the vision
# encoder's expected input size — TODO confirm against the model config)
# and normalize with the standard ImageNet channel statistics.
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
|
|
|
|
|
class ModelInference:
    """Loads a MultiModalDenseTransformer checkpoint and serves generation.

    Responsibilities:
      * load the tokenizer (ensuring a pad token exists),
      * build the model from a JSON config file or a built-in fallback config,
      * load checkpoint weights (tolerating DataParallel/DDP 'module.' prefixes),
      * run sampling-based generation over a prompt plus an optional image.
    """

    def __init__(self, checkpoint_path: str, tokenizer_name: str, config_path: Optional[str] = None, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        """Initialize tokenizer, model and preprocessor, then load weights.

        Args:
            checkpoint_path: Path to a .pt checkpoint — either a raw
                state_dict or a dict carrying a 'model_state_dict' entry.
            tokenizer_name: Hugging Face repo id or local tokenizer path.
            config_path: Optional JSON file of model hyper-parameters; when
                absent, a hard-coded fallback config is used (it must match
                the architecture the checkpoint was trained with).
            device: Torch device string. The default is evaluated once at
                import time, which is acceptable because CUDA availability
                does not change during the process lifetime.
        """
        self.device = torch.device(device)
        print(f"Using device: {self.device}")
        print(f"Loading tokenizer: {tokenizer_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True, trust_remote_code=True)
        if self.tokenizer.pad_token is None:
            # Decoder-only tokenizers often ship without a pad token; reuse
            # EOS so padding during generation does not crash.
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id

        if config_path and Path(config_path).exists():
            with open(config_path, 'r') as f:
                self.config = json.load(f)
        else:
            # Fallback hyper-parameters; these must mirror the training
            # configuration or load_state_dict below will report mismatches.
            self.config = {
                'model_dim': 1536,
                'vocab_size': len(self.tokenizer),
                'n_layers': 12,
                'n_heads': 12,
                'n_kv_heads': 4,
                'head_dim': None,
                'max_seq_len': 512,
                'dropout': 0.0,
                'use_moe': False,
                'use_adapter': False,
                'use_lora': False,
                'rope_scaling_type': "yarn",
                'use_multimodal_fusion': False,
                'use_contrastive': False
            }

        print("Initializing model architecture...")
        self.model = MultiModalDenseTransformer(**self.config)
        self.preprocessor = UnifiedMultiModalPreprocessor(model_dim=self.config['model_dim'])

        print(f"Loading checkpoint from {checkpoint_path}...")
        # NOTE(review): torch >= 2.6 flips torch.load's weights_only default
        # to True, which can reject checkpoints containing pickled objects —
        # confirm the torch version in use and pass weights_only explicitly
        # if loading starts to fail.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        state_dict = checkpoint.get('model_state_dict', checkpoint) if isinstance(checkpoint, dict) else checkpoint

        # Strip the 'module.' prefix that DataParallel/DDP wrappers prepend,
        # so the weights map onto the bare model.
        prefix = 'module.'
        new_state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }

        # strict=False tolerates partial checkpoints; surface the counts so
        # silent architecture mismatches are at least visible in the log.
        missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
        if missing:
            print(f"Warning: Missing keys: {len(missing)}")
        if unexpected:
            print(f"Warning: Unexpected keys: {len(unexpected)}")

        self.model.to(self.device)
        self.preprocessor.to(self.device)
        self.model.eval()
        print("Model loaded successfully!")
        print(f"Total parameters: {sum(p.numel() for p in self.model.parameters())/1e6:.2f}M")

    @torch.no_grad()
    def generate_text(self, prompt: str, max_new_tokens: int = 128, temperature: float = 0.7, top_k: int = 10, top_p: float = 0.9, repetition_penalty: float = 1.2, image: Optional[Image.Image] = None) -> str:
        """Generate a response for ``prompt``, optionally conditioned on ``image``.

        Returns:
            The decoded, post-processed answer text, or an ``"Error: ..."``
            string if generation raised (the traceback is printed to stdout).
        """
        formatted_prompt = f"Instruction: {prompt}\nResponse:"
        inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
        input_ids = inputs['input_ids'].to(self.device)

        # Segments are passed to the model in order: image segments (if any)
        # first, then the text prompt.
        input_data = {'segments': []}
        if image is not None:
            try:
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                image_tensor = image_transform(image).unsqueeze(0).to(self.device)
                mod_segments = self.preprocessor.process_batch(image_tensor, 'image')
                input_data['segments'].extend(mod_segments)
            except Exception as e:
                # Best-effort: fall back to text-only generation rather than
                # failing the whole request on a bad image.
                print(f"Warning: Image processing skipped due to error: {e}")

        input_data['segments'].append({
            'type': 'text',
            'data': input_ids,
            'modality_id': 0
        })

        try:
            generated_ids = self.model.generate(
                input_data,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                do_sample=True,
                eos_token_id=self.tokenizer.eos_token_id,
                pad_token_id=self.tokenizer.pad_token_id
            )

            full_output = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)

            # Keep only the text after the final "Response:" marker.
            if "Response:" in full_output:
                answer = full_output.split("Response:")[-1].strip()
            else:
                answer = full_output

            # Truncate at the first prompt-template artifact the model echoes.
            stop_words = ["Instruction", "Input", "###", "Response", "User:", "Assistant:", "\n\n"]
            for sw in stop_words:
                if sw in answer:
                    answer = answer.split(sw)[0].strip()

            # Drop a leading line that merely repeats the prompt. str.split
            # always yields at least one element, so no length guard needed.
            lines = answer.split('\n')
            if prompt.lower() in lines[0].lower():
                answer = "\n".join(lines[1:]).strip()
            return answer
        except Exception as e:
            import traceback
            traceback.print_exc()
            return f"Error: {e}"
|
|
|
|
|
def build_ui(model_instance):
    """Assemble the Gradio Blocks UI around ``model_instance`` and return it.

    Bug fix: the previous version called ``demo.launch(share=True)`` inside
    this builder, so the server started before ``main()`` could apply
    ``--port``/``--share`` and then launched a second time. Building and
    launching are now separated — the caller owns ``demo.launch(...)``.
    """
    with gr.Blocks(title="MultiModal Dense Transformer - Gradio", css="""
        .gradio-container { max-width: 900px; margin: auto; }
    """) as demo:
        gr.Markdown("## 多模态在线推理(文本 + 图片)")
        with gr.Row():
            with gr.Column(scale=3):
                txt = gr.Textbox(label="Prompt (Instruction)", placeholder="请输入指令或问题...", lines=5)
                img = gr.Image(type="pil", label="(可选) 上传图片(支持多模态)")
                btn = gr.Button("生成 (Generate)")
            with gr.Column(scale=2):
                max_tokens = gr.Slider(label="Max New Tokens", minimum=16, maximum=1024, step=1, value=128)
                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, step=0.01, value=0.7)
                top_k = gr.Slider(label="Top-k", minimum=0, maximum=200, step=1, value=40)
                top_p = gr.Slider(label="Top-p", minimum=0.0, maximum=1.0, step=0.01, value=0.9)
                rep_pen = gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, step=0.01, value=1.1)
                status = gr.Textbox(label="Status", value="Ready", interactive=False)
        output = gr.Textbox(label="Output", lines=12, interactive=False)

        def gr_generate(prompt, image, max_tokens_v, temp_v, topk_v, topp_v, rep_v):
            """Validate the prompt, run generation, return (output, status)."""
            if not prompt or str(prompt).strip() == "":
                return "", "请输入 Prompt"
            out = model_instance.generate_text(prompt=prompt,
                                               max_new_tokens=int(max_tokens_v),
                                               temperature=float(temp_v),
                                               top_k=int(topk_v),
                                               top_p=float(topp_v),
                                               repetition_penalty=float(rep_v),
                                               image=image)
            return out, "Done"

        # Two outputs only: the previous third return value fed a dangling
        # gr.State() created inline in the outputs list, which served no purpose.
        btn.click(fn=gr_generate, inputs=[txt, img, max_tokens, temperature, top_k, top_p, rep_pen], outputs=[output, status])

    return demo
|
|
|
|
|
def main():
    """CLI entry point: parse args, resolve the checkpoint, build and launch the UI."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", type=str, default="/root/multimodal/checkpoints/posttrain/final_model.pt")
    parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--config", type=str, default=None)
    parser.add_argument("--port", type=int, default=7860)
    # argparse's type=bool is a known footgun (bool("false") is True), so
    # parse truthy strings explicitly: "true"/"1"/"yes" (case-insensitive).
    parser.add_argument("--share", type=lambda x: x.lower() in ("true", "1", "yes"), default=True)
    args = parser.parse_args()

    if not Path(args.checkpoint).exists():
        # Fall back to the newest pretrain checkpoint. Path.glob() yields
        # results in arbitrary order, so sort by modification time — the
        # previous possible[-1] on the unsorted list did not actually pick
        # the latest file.
        possible = sorted(Path("checkpoints/pretrain").glob("step_*.pt"),
                          key=lambda p: p.stat().st_mtime)
        if possible:
            args.checkpoint = str(possible[-1])
            print(f"未找到 final_model.pt,使用最新 checkpoint: {args.checkpoint}")
        else:
            raise FileNotFoundError(f"找不到检查点: {args.checkpoint}")

    # Kept as a module-level global for interactive debugging / backcompat.
    global model_instance
    model_instance = ModelInference(args.checkpoint, args.tokenizer, args.config)

    demo = build_ui(model_instance)
    demo.launch(server_port=args.port, share=args.share)
|
|
|
|
|
# Script entry point: only run the server when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|
|