import os
import argparse
from pathlib import Path
import json
from typing import Optional
import torch
from PIL import Image
# Point Hugging Face downloads at the mirror before transformers/gradio are imported,
# since huggingface_hub reads HF_ENDPOINT when it is first loaded.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from transformers import AutoTokenizer
from torchvision import transforms

# UI
import gradio as gr

from model import MultiModalDenseTransformer
from continual_learning import UnifiedMultiModalPreprocessor
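# Standard ImageNet-style preprocessing: resize to 224x224 and normalize with the
# usual ImageNet channel means/stds expected by most pretrained vision backbones.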
image_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
class ModelInference:
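    """Wraps tokenizer setup, checkpoint loading, and text/image generation for the demo."""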
def __init__(self, checkpoint_path: str, tokenizer_name: str, config_path: Optional[str] = None, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
self.device = torch.device(device)
print(f"Using device: {self.device}")
print(f"Loading tokenizer: {tokenizer_name}...")
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True, trust_remote_code=True)
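        # Some tokenizers ship without a pad token; fall back to EOS so padded batches work.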
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
if config_path and Path(config_path).exists():
with open(config_path, 'r') as f:
self.config = json.load(f)
else:
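            # Fallback hyperparameters used when no config file is given; these are assumed
            # to match the architecture the checkpoint was trained with.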
self.config = {
'model_dim': 1536,
'vocab_size': len(self.tokenizer),
'n_layers': 12,
'n_heads': 12,
'n_kv_heads': 4,
'head_dim': None,
'max_seq_len': 512,
'dropout': 0.0,
'use_moe': False,
'use_adapter': False,
'use_lora': False,
'rope_scaling_type': "yarn",
'use_multimodal_fusion': False,
'use_contrastive': False
}
# init model + preprocessor
print("Initializing model architecture...")
self.model = MultiModalDenseTransformer(**self.config)
self.preprocessor = UnifiedMultiModalPreprocessor(model_dim=self.config['model_dim'])
print(f"Loading checkpoint from {checkpoint_path}...")
checkpoint = torch.load(checkpoint_path, map_location=self.device)
state_dict = checkpoint.get('model_state_dict', checkpoint) if isinstance(checkpoint, dict) else checkpoint
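        # Strip the 'module.' prefix that DistributedDataParallel adds to parameter names.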
new_state_dict = {}
for k, v in state_dict.items():
if k.startswith('module.'):
new_state_dict[k[7:]] = v
else:
new_state_dict[k] = v
missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
if missing:
print(f"Warning: Missing keys: {len(missing)}")
if unexpected:
print(f"Warning: Unexpected keys: {len(unexpected)}")
self.model.to(self.device)
self.preprocessor.to(self.device)
self.model.eval()
print("Model loaded successfully!")
print(f"Total parameters: {sum(p.numel() for p in self.model.parameters())/1e6:.2f}M")
@torch.no_grad()
def generate_text(self, prompt: str, max_new_tokens: int = 128, temperature: float = 0.7, top_k: int = 10, top_p: float = 0.9, repetition_penalty: float = 1.2, image: Optional[Image.Image] = None) -> str:
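        """Generate a response to an instruction-style prompt, optionally conditioned on an image."""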
formatted_prompt = f"Instruction: {prompt}\nResponse:"
inputs = self.tokenizer(formatted_prompt, return_tensors="pt")
input_ids = inputs['input_ids'].to(self.device)
input_data = {'segments': []}
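        # The model consumes a list of modality segments; image segments (if any) are
        # prepended before the text segment appended below.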
if image is not None:
try:
if image.mode != 'RGB':
image = image.convert('RGB')
image_tensor = image_transform(image).unsqueeze(0).to(self.device)
mod_segments = self.preprocessor.process_batch(image_tensor, 'image')
for seg in mod_segments:
input_data['segments'].append(seg)
except Exception as e:
print(f"Warning: Image processing skipped due to error: {e}")
input_data['segments'].append({
'type': 'text',
'data': input_ids,
'modality_id': 0
})
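        # Run autoregressive generation over the assembled segments; any failure is
        # caught and surfaced as an error string in the output box.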
try:
generated_ids = self.model.generate(
input_data,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
do_sample=True,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id
)
full_output = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            # Keep only the text after "Response:" and truncate at the first stop word.
if "Response:" in full_output:
answer = full_output.split("Response:")[-1].strip()
else:
answer = full_output
stop_words = ["Instruction", "Input", "###", "Response", "User:", "Assistant:", "\n\n"]
for sw in stop_words:
if sw in answer:
answer = answer.split(sw)[0].strip()
            # Drop a leading line that merely echoes the prompt.
lines = answer.split('\n')
if len(lines) > 0 and prompt.lower() in lines[0].lower():
answer = "\n".join(lines[1:]).strip()
return answer
except Exception as e:
import traceback
traceback.print_exc()
return f"Error: {e}"
def build_ui(model_instance):
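    """Build the Gradio Blocks UI: prompt and image inputs, sampling sliders, and output boxes."""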
with gr.Blocks(title="MultiModal Dense Transformer - Gradio", css="""
.gradio-container { max-width: 900px; margin: auto; }
""") as demo:
        gr.Markdown("## Multimodal Online Inference (Text + Images)")
with gr.Row():
with gr.Column(scale=3):
                txt = gr.Textbox(label="Prompt (Instruction)", placeholder="Enter an instruction or question...", lines=5)
                img = gr.Image(type="pil", label="(Optional) Upload an image for multimodal input")
                btn = gr.Button("Generate")
with gr.Column(scale=2):
max_tokens = gr.Slider(label="Max New Tokens", minimum=16, maximum=1024, step=1, value=128)
temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.5, step=0.01, value=0.7)
top_k = gr.Slider(label="Top-k", minimum=0, maximum=200, step=1, value=40)
top_p = gr.Slider(label="Top-p", minimum=0.0, maximum=1.0, step=0.01, value=0.9)
rep_pen = gr.Slider(label="Repetition Penalty", minimum=0.5, maximum=2.0, step=0.01, value=1.1)
status = gr.Textbox(label="Status", value="Ready", interactive=False)
output = gr.Textbox(label="Output", lines=12, interactive=False)
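        # Button callback defined below: validate the prompt, run generation, and report status.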
        def gr_generate(prompt, image, max_tokens_v, temp_v, topk_v, topp_v, rep_v):
            if not prompt or str(prompt).strip() == "":
                return "", "Please enter a prompt"
            # Call the model, casting slider values to the types generate_text expects.
            out = model_instance.generate_text(prompt=prompt,
                                               max_new_tokens=int(max_tokens_v),
                                               temperature=float(temp_v),
                                               top_k=int(topk_v),
                                               top_p=float(topp_v),
                                               repetition_penalty=float(rep_v),
                                               image=image)
            return out, "Done"

        btn.click(fn=gr_generate, inputs=[txt, img, max_tokens, temperature, top_k, top_p, rep_pen], outputs=[output, status])
return demo
def main():
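    """Parse CLI arguments, load the model, and launch the Gradio app."""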
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", type=str, default="/root/multimodal/checkpoints/posttrain/final_model.pt")
parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2.5-7B-Instruct")
parser.add_argument("--config", type=str, default=None)
parser.add_argument("--port", type=int, default=7860)
parser.add_argument("--share", type=lambda x: x.lower() in ("true","1","yes"), default=True)
args = parser.parse_args()
    if not Path(args.checkpoint).exists():
        # Fall back to the most recent pretraining checkpoint; sort so "latest" is well-defined.
        possible = sorted(Path("checkpoints/pretrain").glob("step_*.pt"), key=lambda p: p.stat().st_mtime)
        if possible:
            args.checkpoint = str(possible[-1])
            print(f"final_model.pt not found; using the latest checkpoint: {args.checkpoint}")
        else:
            raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint}")
global model_instance
model_instance = ModelInference(args.checkpoint, args.tokenizer, args.config)
demo = build_ui(model_instance)
demo.launch(server_port=args.port, share=args.share)
if __name__ == "__main__":
main()
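
# Example invocation (a sketch; the checkpoint path is an assumption and must exist locally,
# the tokenizer is fetched via the configured mirror):
#   python gradio1.py --checkpoint checkpoints/posttrain/final_model.pt \
#       --tokenizer Qwen/Qwen2.5-7B-Instruct --port 7860 --share true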