szxllm committed on
Commit
4f003b4
·
verified ·
1 Parent(s): 6328772

Update gradio1.py

Browse files
Files changed (1) hide show
  1. gradio1.py +1 -22
gradio1.py CHANGED
@@ -1,13 +1,3 @@
1
- """
2
- Gradio 推理界面 - 多模态 Dense Transformer (适配 Qwen Tokenizer 版)
3
-
4
- 用法:
5
- pip install -r requirements.txt
6
- # requirements.txt 至少包含:
7
- # torch>=1.12, transformers, pillow, gradio
8
- python app_gradio.py --checkpoint /path/to/final_model.pt --tokenizer Qwen/Qwen2.5-7B-Instruct --port 7860 --share False
9
- """
10
-
11
  import os
12
  import argparse
13
  from pathlib import Path
@@ -20,15 +10,11 @@ from transformers import AutoTokenizer
20
 
21
  # UI
22
  import gradio as gr
23
-
24
- # 本项目代码引用(按你的工程结构调整)
25
  from model import MultiModalDenseTransformer
26
  from continual_learning import UnifiedMultiModalPreprocessor
27
 
28
- # 设置国内镜像(如需要)
29
  os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
30
 
31
- # ---- 与你原来保持一致的图像预处理 ----
32
  from torchvision import transforms
33
  image_transform = transforms.Compose([
34
  transforms.Resize((224, 224)),
@@ -37,7 +23,6 @@ image_transform = transforms.Compose([
37
  std=[0.229, 0.224, 0.225]),
38
  ])
39
 
40
- # -------- ModelInference 类(轻微改写) --------
41
  class ModelInference:
42
  def __init__(self, checkpoint_path: str, tokenizer_name: str, config_path: Optional[str] = None, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
43
  self.device = torch.device(device)
@@ -52,7 +37,6 @@ class ModelInference:
52
  with open(config_path, 'r') as f:
53
  self.config = json.load(f)
54
  else:
55
- # 采用你原始脚本中的默认 config(可按需调整)
56
  self.config = {
57
  'model_dim': 1536,
58
  'vocab_size': len(self.tokenizer),
@@ -77,7 +61,6 @@ class ModelInference:
77
 
78
  print(f"Loading checkpoint from {checkpoint_path}...")
79
  checkpoint = torch.load(checkpoint_path, map_location=self.device)
80
- # 支持 checkpoint 包含 'model_state_dict' 的情况
81
  state_dict = checkpoint.get('model_state_dict', checkpoint) if isinstance(checkpoint, dict) else checkpoint
82
 
83
  new_state_dict = {}
@@ -158,12 +141,11 @@ class ModelInference:
158
  traceback.print_exc()
159
  return f"Error: {e}"
160
 
161
- # -------- Gradio UI 部分 --------
162
  def build_ui(model_instance):
163
  with gr.Blocks(title="MultiModal Dense Transformer - Gradio", css="""
164
  .gradio-container { max-width: 900px; margin: auto; }
165
  """) as demo:
166
- gr.Markdown("## 🚀 多模态在线推理(文本 + 图片)")
167
  with gr.Row():
168
  with gr.Column(scale=3):
169
  txt = gr.Textbox(label="Prompt (Instruction)", placeholder="请输入指令或问题...", lines=5)
@@ -198,7 +180,6 @@ def build_ui(model_instance):
198
 
199
  return demo
200
 
201
- # -------- CLI / main --------
202
  def main():
203
  parser = argparse.ArgumentParser()
204
  parser.add_argument("--checkpoint", type=str, default="/root/multimodal/checkpoints/posttrain/final_model.pt")
@@ -208,7 +189,6 @@ def main():
208
  parser.add_argument("--share", type=lambda x: x.lower() in ("true","1","yes"), default=True)
209
  args = parser.parse_args()
210
 
211
- # 如果 default 的 final_model 不存在,尝试寻找最近 step
212
  if not Path(args.checkpoint).exists():
213
  possible = list(Path("checkpoints/pretrain").glob("step_*.pt"))
214
  if possible:
@@ -220,7 +200,6 @@ def main():
220
  global model_instance
221
  model_instance = ModelInference(args.checkpoint, args.tokenizer, args.config)
222
 
223
- # 启动 Gradio(使用 share 参数决定是否创建公网链接)
224
  demo = build_ui(model_instance)
225
  demo.launch(server_port=args.port, share=args.share)
226
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import argparse
3
  from pathlib import Path
 
10
 
11
  # UI
12
  import gradio as gr
 
 
13
  from model import MultiModalDenseTransformer
14
  from continual_learning import UnifiedMultiModalPreprocessor
15
 
 
16
  os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
17
 
 
18
  from torchvision import transforms
19
  image_transform = transforms.Compose([
20
  transforms.Resize((224, 224)),
 
23
  std=[0.229, 0.224, 0.225]),
24
  ])
25
 
 
26
  class ModelInference:
27
  def __init__(self, checkpoint_path: str, tokenizer_name: str, config_path: Optional[str] = None, device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
28
  self.device = torch.device(device)
 
37
  with open(config_path, 'r') as f:
38
  self.config = json.load(f)
39
  else:
 
40
  self.config = {
41
  'model_dim': 1536,
42
  'vocab_size': len(self.tokenizer),
 
61
 
62
  print(f"Loading checkpoint from {checkpoint_path}...")
63
  checkpoint = torch.load(checkpoint_path, map_location=self.device)
 
64
  state_dict = checkpoint.get('model_state_dict', checkpoint) if isinstance(checkpoint, dict) else checkpoint
65
 
66
  new_state_dict = {}
 
141
  traceback.print_exc()
142
  return f"Error: {e}"
143
 
 
144
  def build_ui(model_instance):
145
  with gr.Blocks(title="MultiModal Dense Transformer - Gradio", css="""
146
  .gradio-container { max-width: 900px; margin: auto; }
147
  """) as demo:
148
+ gr.Markdown("## 多模态在线推理(文本 + 图片)")
149
  with gr.Row():
150
  with gr.Column(scale=3):
151
  txt = gr.Textbox(label="Prompt (Instruction)", placeholder="请输入指令或问题...", lines=5)
 
180
 
181
  return demo
182
 
 
183
  def main():
184
  parser = argparse.ArgumentParser()
185
  parser.add_argument("--checkpoint", type=str, default="/root/multimodal/checkpoints/posttrain/final_model.pt")
 
189
  parser.add_argument("--share", type=lambda x: x.lower() in ("true","1","yes"), default=True)
190
  args = parser.parse_args()
191
 
 
192
  if not Path(args.checkpoint).exists():
193
  possible = list(Path("checkpoints/pretrain").glob("step_*.pt"))
194
  if possible:
 
200
  global model_instance
201
  model_instance = ModelInference(args.checkpoint, args.tokenizer, args.config)
202
 
 
203
  demo = build_ui(model_instance)
204
  demo.launch(server_port=args.port, share=args.share)
205