szxllm committed on
Commit
2e9a238
·
verified ·
1 Parent(s): 68004aa

Update infer.py

Browse files
Files changed (1) hide show
  1. infer.py +306 -263
infer.py CHANGED
@@ -1,150 +1,176 @@
1
  import os
2
- import torch
3
- import torch.nn.functional as F
4
- from flask import Flask, render_template, request, jsonify
5
- from transformers import AutoTokenizer
6
- from PIL import Image
7
- import json
8
- import io
9
- import base64
10
  from pathlib import Path
 
11
  from typing import Optional
12
 
 
 
 
 
 
13
  from model import MultiModalDenseTransformer
14
  from continual_learning import UnifiedMultiModalPreprocessor
15
- from torchvision import transforms
16
-
17
  os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
18
 
 
19
  image_transform = transforms.Compose([
20
  transforms.Resize((224, 224)),
21
  transforms.ToTensor(),
22
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 
23
  ])
24
-
25
  class ModelInference:
26
  def __init__(
27
- self,
28
- checkpoint_path: str,
29
- tokenizer_name: str,
30
- config_path: Optional[str] = None,
31
  device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
32
  ):
33
  self.device = torch.device(device)
34
- print(f"Using device: {self.device}")
35
- print(f"Loading tokenizer: {tokenizer_name}...")
36
- try:
37
- self.tokenizer = AutoTokenizer.from_pretrained(
38
- tokenizer_name,
39
- use_fast=True,
40
- trust_remote_code=True
41
- )
42
- if self.tokenizer.pad_token is None:
43
- self.tokenizer.pad_token = self.tokenizer.eos_token
44
- self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
45
- except Exception as e:
46
- print(f"Error loading tokenizer: {e}")
47
- raise e
48
-
49
  if config_path and Path(config_path).exists():
50
  with open(config_path, 'r') as f:
51
  self.config = json.load(f)
52
  else:
53
  self.config = {
54
- 'model_dim': 1536,
55
- 'vocab_size': len(self.tokenizer),
56
  'n_layers': 12,
57
- 'n_heads': 12,
58
- 'n_kv_heads': 4,
59
- 'head_dim': None,
60
- 'max_seq_len': 512,
61
- 'dropout': 0.0,
62
- 'use_moe': False,
63
- 'use_adapter': False,
64
- 'use_lora': False,
65
- 'rope_scaling_type': "yarn"
 
 
66
  }
 
 
 
 
 
 
67
 
68
- # 3. 初始化模型结构
69
- print("Initializing model architecture...")
70
- try:
71
- self.model = MultiModalDenseTransformer(**self.config)
72
- self.preprocessor = UnifiedMultiModalPreprocessor(
73
- model_dim=self.config['model_dim']
74
- )
75
-
76
- # 4. 加载权重
77
- print(f"Loading checkpoint from {checkpoint_path}...")
78
- checkpoint = torch.load(
79
- checkpoint_path,
80
- map_location=self.device,
81
- weights_only=False
82
- )
83
-
84
- if 'model_state_dict' in checkpoint:
85
- print("Found 'model_state_dict' in checkpoint.")
86
- state_dict = checkpoint['model_state_dict']
87
  else:
88
- state_dict = checkpoint
89
-
90
- new_state_dict = {}
91
- for k, v in state_dict.items():
92
- if k.startswith('module.'):
93
- new_state_dict[k[7:]] = v
94
- else:
95
- new_state_dict[k] = v
96
-
97
- missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
98
- if missing:
99
- print(f"Warning: Missing keys: {len(missing)}")
100
- if unexpected:
101
- print(f"Warning: Unexpected keys: {len(unexpected)}")
102
-
103
- self.model.to(self.device)
104
- self.preprocessor.to(self.device)
105
- self.model.eval()
106
-
107
- print("Model loaded successfully!")
108
- print(f"Total parameters: {sum(p.numel() for p in self.model.parameters()) / 1e6:.2f}M")
109
-
110
- except Exception as e:
111
- print(f"Error initializing model: {e}")
112
- raise e
113
-
 
 
 
 
 
 
 
 
 
114
  @torch.no_grad()
115
  def generate_text(
116
- self,
117
- prompt: str,
118
- max_new_tokens: int = 128,
119
- temperature: float = 0.7,
120
- top_k: int = 40,
121
- top_p: float = 0.9,
122
- repetition_penalty: float = 1.1,
123
  image: Optional[Image.Image] = None
124
  ) -> str:
125
- """生成文本"""
126
- inputs = self.tokenizer(prompt, return_tensors="pt")
127
- input_ids = inputs['input_ids'].to(self.device)
128
- input_data = {'segments': []}
129
 
130
- # 处理图像
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  if image is not None:
132
- if image.mode != 'RGB':
133
- image = image.convert('RGB')
134
- image_tensor = image_transform(image).unsqueeze(0).to(self.device)
135
  try:
136
- mod_segments = self.preprocessor.process_batch(image_tensor, 'image')
137
- for seg in mod_segments:
138
- input_data['segments'].append(seg)
 
 
 
 
 
 
 
 
 
139
  except Exception as e:
140
- print(f"Warning: Image processing skipped due to error: {e}")
141
-
142
- input_data['segments'].append({
143
- 'type': 'text',
144
- 'data': input_ids,
145
- 'modality_id': 0
146
- })
147
-
 
 
148
  try:
149
  generated_ids = self.model.generate(
150
  input_data,
@@ -158,174 +184,191 @@ class ModelInference:
158
  pad_token_id=self.tokenizer.pad_token_id
159
  )
160
 
161
- generated_text = self.tokenizer.decode(
162
- generated_ids[0],
163
- skip_special_tokens=True
164
- )
165
- return generated_text
166
 
167
  except Exception as e:
168
- print(f"Generation error: {e}")
169
  import traceback
170
  traceback.print_exc()
171
- return f"Error: {str(e)}"
172
-
173
- model_instance = None
174
- app = Flask(__name__)
175
- app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
176
-
177
- @app.route('/')
178
- def index():
179
- display_config = model_instance.config.copy() if model_instance else {}
180
- return render_template('index.html', config=display_config)
181
 
182
- @app.route('/generate', methods=['POST'])
183
- def generate():
184
- try:
185
- data = request.json
186
- prompt = data.get('prompt', '')
187
- if not prompt.strip():
188
- return jsonify({'error': '请输入提示文本'}), 400
189
 
190
- max_tokens = int(data.get('max_tokens', 100))
191
- temperature = float(data.get('temperature', 0.7))
192
- top_k = int(data.get('top_k', 40))
193
- top_p = float(data.get('top_p', 0.9))
194
- repetition_penalty = float(data.get('repetition_penalty', 1.1))
195
-
196
- image = None
197
- if 'image' in data and data['image']:
198
- try:
199
- image_data = base64.b64decode(data['image'].split(',')[1])
200
- image = Image.open(io.BytesIO(image_data))
201
- except Exception as e:
202
- print(f"Image load error: {e}")
203
-
204
- output = model_instance.generate_text(
205
- prompt, max_tokens, temperature, top_k, top_p, repetition_penalty, image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  )
207
- return jsonify({'output': output})
208
-
209
- except Exception as e:
210
- return jsonify({'error': str(e)}), 500
211
 
212
- def create_html_template():
213
- """写入HTML模板"""
214
- html_content = '''
215
- <!DOCTYPE html>
216
- <html lang="zh-CN">
217
- <head>
218
- <meta charset="UTF-8">
219
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
220
- <title>Model Inference</title>
221
- <style>
222
- body { font-family: sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; background: #f0f2f5; }
223
- .container { background: white; padding: 30px; border-radius: 12px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); }
224
- h1 { color: #1a73e8; text-align: center; }
225
- textarea { width: 100%; padding: 10px; border: 1px solid #ddd; border-radius: 8px; margin: 10px 0; min-height: 100px; }
226
- .controls { display: grid; grid-template-columns: 1fr 1fr; gap: 15px; margin: 20px 0; background: #f8f9fa; padding: 15px; border-radius: 8px; }
227
- button { background: #1a73e8; color: white; border: none; padding: 12px 24px; border-radius: 6px; cursor: pointer; width: 100%; font-size: 16px; transition: background 0.3s; }
228
- button:hover { background: #1557b0; }
229
- button:disabled { background: #ccc; }
230
- #output { margin-top: 20px; padding: 20px; background: #f8f9fa; border-radius: 8px; white-space: pre-wrap; min-height: 100px; border: 1px solid #e0e0e0; }
231
- .loading { color: #666; font-style: italic; }
232
- </style>
233
- </head>
234
- <body>
235
- <div class="container">
236
- <h1> 模型在线推理</h1>
237
-
238
- <div>
239
- <label><strong>提示词 (Prompt):</strong></label>
240
- <textarea id="prompt" placeholder="请输入你的问题..."></textarea>
241
- </div>
242
 
243
- <div class="controls">
244
- <div>
245
- <label>Max Tokens: <span id="maxTokensVal">128</span></label>
246
- <input type="range" id="maxTokens" min="32" max="1024" value="128" style="width:100%" oninput="document.getElementById('maxTokensVal').innerText=this.value">
247
- </div>
248
- <div>
249
- <label>Temperature: <span id="tempVal">0.7</span></label>
250
- <input type="range" id="temperature" min="0.1" max="1.5" step="0.1" value="0.7" style="width:100%" oninput="document.getElementById('tempVal').innerText=this.value">
251
- </div>
252
- </div>
253
 
254
- <button id="btn" onclick="generate()">生成 (Generate)</button>
255
-
256
- <div id="output">结果将显示在这里...</div>
257
- </div>
258
-
259
- <script>
260
- async function generate() {
261
- const prompt = document.getElementById('prompt').value;
262
- if(!prompt) return alert("请输入内容");
263
-
264
- const btn = document.getElementById('btn');
265
- const out = document.getElementById('output');
266
-
267
- btn.disabled = true;
268
- btn.innerText = "生成中...";
269
- out.innerHTML = '<div class="loading">正在思考中...</div>';
270
-
271
- try {
272
- const res = await fetch('/generate', {
273
- method: 'POST',
274
- headers: {'Content-Type': 'application/json'},
275
- body: JSON.stringify({
276
- prompt: prompt,
277
- max_tokens: parseInt(document.getElementById('maxTokens').value),
278
- temperature: parseFloat(document.getElementById('temperature').value)
279
- })
280
- });
281
- const data = await res.json();
282
- if(data.error) out.innerText = "Error: " + data.error;
283
- else out.innerText = data.output;
284
- } catch(e) {
285
- out.innerText = "请求失败: " + e;
286
- } finally {
287
- btn.disabled = false;
288
- btn.innerText = "生成 (Generate)";
289
- }
290
- }
291
- </script>
292
- </body>
293
- </html>
294
- '''
295
-
296
- Path('templates').mkdir(exist_ok=True)
297
- with open('templates/index.html', 'w', encoding='utf-8') as f:
298
- f.write(html_content)
299
 
300
  def main():
301
- import argparse
302
- parser = argparse.ArgumentParser()
303
- # 默认指向 pretrain 保存的 checkpoint 路径
304
- parser.add_argument("--checkpoint", type=str, default="/root/multimodal/checkpoints/pretrain_fixed/step_10000.pt")
305
- parser.add_argument("--tokenizer", type=str, default="Qwen/Qwen2.5-7B-Instruct")
306
- parser.add_argument("--port", type=int, default=5001)
307
- parser.add_argument("--host", type=str, default="0.0.0.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
  args = parser.parse_args()
309
-
310
  if not Path(args.checkpoint).exists():
311
- # 尝试找最近的 step checkpoint
312
- steps = list(Path("checkpoints/pretrain").glob("step_*.pt"))
313
- if steps:
314
- print(f"未找到 final_model.pt,尝试使用最新的 checkpoint: {steps[-1]}")
315
- args.checkpoint = str(steps[-1])
316
- else:
317
- print(f"错误: 找不到检查点文件: {args.checkpoint}")
318
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
- create_html_template()
321
 
322
  global model_instance
323
- model_instance = ModelInference(args.checkpoint, args.tokenizer)
 
 
 
 
 
324
 
325
- print(f"\n服务已启动: http://{args.host}:{args.port}")
326
- app.run(host=args.host, port=args.port,
327
- debug=True, # 开启调试模式
328
- use_reloader=False)
 
 
329
 
330
  if __name__ == "__main__":
331
  main()
 
1
  import os
2
+ import argparse
 
 
 
 
 
 
 
3
  from pathlib import Path
4
+ import json
5
  from typing import Optional
6
 
7
+ import torch
8
+ from PIL import Image
9
+ from transformers import AutoTokenizer
10
+
11
+ import gradio as gr
12
  from model import MultiModalDenseTransformer
13
  from continual_learning import UnifiedMultiModalPreprocessor
 
 
14
  os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
15
 
16
+ from torchvision import transforms
17
  image_transform = transforms.Compose([
18
  transforms.Resize((224, 224)),
19
  transforms.ToTensor(),
20
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
21
+ std=[0.229, 0.224, 0.225]),
22
  ])
 
23
  class ModelInference:
24
  def __init__(
25
+ self,
26
+ checkpoint_path: str,
27
+ tokenizer_name: str,
28
+ config_path: Optional[str] = None,
29
  device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
30
  ):
31
  self.device = torch.device(device)
32
+ self.tokenizer = AutoTokenizer.from_pretrained(
33
+ tokenizer_name,
34
+ use_fast=True,
35
+ trust_remote_code=True
36
+ )
37
+ if self.tokenizer.pad_token is None:
38
+ self.tokenizer.pad_token = self.tokenizer.eos_token
39
+ self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
 
 
 
 
 
 
 
40
  if config_path and Path(config_path).exists():
41
  with open(config_path, 'r') as f:
42
  self.config = json.load(f)
43
  else:
44
  self.config = {
45
+ 'model_dim': 1536,
46
+ 'vocab_size': len(self.tokenizer),
47
  'n_layers': 12,
48
+ 'n_heads': 12,
49
+ 'n_kv_heads': 4,
50
+ 'head_dim': None,
51
+ 'max_seq_len': 512,
52
+ 'dropout': 0.0,
53
+ 'use_moe': False,
54
+ 'use_adapter': False,
55
+ 'use_lora': False,
56
+ 'rope_scaling_type': "yarn",
57
+ 'use_multimodal_fusion': False,
58
+ 'use_contrastive': False
59
  }
60
+
61
+ self.model = MultiModalDenseTransformer(**self.config)
62
+ self.preprocessor = UnifiedMultiModalPreprocessor(model_dim=self.config['model_dim'])
63
+
64
+ print(f"Loading checkpoint from {checkpoint_path}...")
65
+ checkpoint = torch.load(checkpoint_path, map_location=self.device)
66
 
67
+ state_dict = None
68
+
69
+ if 'actor_state_dict' in checkpoint:
70
+ print("Detected GRPO checkpoint format (actor_state_dict)")
71
+ state_dict = checkpoint['actor_state_dict']
72
+ elif 'model_state_dict' in checkpoint:
73
+ print("Detected Standard/SFT checkpoint format (model_state_dict)")
74
+ state_dict = checkpoint['model_state_dict']
75
+ else:
76
+ print("Detected raw state dict format")
77
+ state_dict = checkpoint
78
+
79
+ new_state_dict = {}
80
+ for k, v in state_dict.items():
81
+ if k.startswith('module.'):
82
+ new_state_dict[k[7:]] = v
 
 
 
83
  else:
84
+ new_state_dict[k] = v
85
+
86
+ missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
87
+ if missing:
88
+ print(f"Warning: Missing keys: {len(missing)}")
89
+ if len(missing) <= 10:
90
+ print(f"Missing keys: {missing}")
91
+ if unexpected:
92
+ print(f"Warning: Unexpected keys: {len(unexpected)}")
93
+ if len(unexpected) <= 10:
94
+ print(f"Unexpected keys: {unexpected}")
95
+
96
+ self.model.to(self.device)
97
+ self.preprocessor.to(self.device)
98
+ self.model.eval()
99
+
100
+
101
def _build_position_ids(self, attention_mask: torch.Tensor) -> torch.Tensor:
    """Build 0-based position ids for the valid tokens of each row.

    Every position where ``attention_mask == 1`` receives its rank among the
    valid tokens of that row (0, 1, 2, ...); all other positions receive 0.
    For left- or right-padded masks this matches numbering the contiguous
    valid span from zero. Masks with gaps are handled as well: padding
    slots stay 0 and real tokens keep consecutive positions.

    Args:
        attention_mask: 2-D tensor of shape ``(batch_size, seq_len)``.

    Returns:
        A ``torch.long`` tensor of the same shape, placed on ``self.device``.

    Raises:
        ValueError: If ``attention_mask`` is not 2-D.
    """
    if attention_mask.dim() != 2:
        raise ValueError(
            f"attention_mask must be 2-D (batch, seq_len), got shape {tuple(attention_mask.shape)}"
        )

    valid = attention_mask == 1
    # A running count of valid tokens minus one is each token's 0-based position;
    # this avoids a Python loop and a per-row .item() synchronisation.
    position_ids = valid.long().cumsum(dim=-1) - 1
    # Padding slots (including leading ones, which would otherwise be -1) are 0.
    position_ids = position_ids.masked_fill(~valid, 0)
    return position_ids.to(self.device)
118
+
119
  @torch.no_grad()
120
  def generate_text(
121
+ self,
122
+ prompt: str,
123
+ max_new_tokens: int = 128,
124
+ temperature: float = 0.7,
125
+ top_k: int = 40,
126
+ top_p: float = 0.9,
127
+ repetition_penalty: float = 1.1,
128
  image: Optional[Image.Image] = None
129
  ) -> str:
130
+ formatted_prompt = f"user: {prompt}\nassistant:\n<think>\n"
 
 
 
131
 
132
+ inputs = self.tokenizer(
133
+ formatted_prompt,
134
+ return_tensors="pt",
135
+ padding=False
136
+ )
137
+ input_ids = inputs['input_ids'].to(self.device)
138
+ attention_mask = inputs['attention_mask'].to(self.device)
139
+
140
+ segments = []
141
+
142
+ segments.append({
143
+ 'type': 'text',
144
+ 'data': input_ids,
145
+ 'modality_id': 0
146
+ })
147
+
148
+ has_image = False
149
  if image is not None:
 
 
 
150
  try:
151
+ if image.mode != 'RGB':
152
+ image = image.convert('RGB')
153
+ image_tensor = image_transform(image).unsqueeze(0).to(self.device)
154
+
155
+ segments.append({
156
+ 'type': 'image',
157
+ 'data': image_tensor,
158
+ 'modality_id': 1
159
+ })
160
+ has_image = True
161
+ print("Image added to input")
162
+
163
  except Exception as e:
164
+ print(f"Warning: Image processing error: {e}")
165
+
166
+ position_ids = self._build_position_ids(attention_mask)
167
+ input_data = {
168
+ 'segments': segments,
169
+ }
170
+ input_data['attention_mask'] = attention_mask
171
+ if not has_image:
172
+ input_data['position_ids'] = position_ids
173
+
174
  try:
175
  generated_ids = self.model.generate(
176
  input_data,
 
184
  pad_token_id=self.tokenizer.pad_token_id
185
  )
186
 
187
+ output_text = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
188
+ return output_text.strip()
 
 
 
189
 
190
  except Exception as e:
 
191
  import traceback
192
  traceback.print_exc()
193
+ return f"Error during generation: {str(e)}"
 
 
 
 
 
 
 
 
 
194
 
195
def build_ui(model_instance):
    """Build the Gradio Blocks interface around ``model_instance``.

    The UI has a prompt box, an optional image upload, sliders for the
    sampling parameters, a status field and an output box. Clicking the
    button calls ``model_instance.generate_text`` with the current values.

    Args:
        model_instance: Object exposing ``generate_text(prompt=...,
            max_new_tokens=..., temperature=..., top_k=..., top_p=...,
            repetition_penalty=..., image=...)``.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    # NOTE(review): widget nesting below was reconstructed from a diff rendering
    # that lost indentation - confirm `output`/`Examples` belong at Blocks level.
    with gr.Blocks(title="MultiModal Dense Transformer - Gradio", css="""
    .gradio-container { max-width: 900px; margin: auto; }
    """) as demo:
        gr.Markdown("## 在线推理(文本)")

        with gr.Row():
            # Left column: prompt, optional image and the trigger button.
            with gr.Column(scale=3):
                txt = gr.Textbox(
                    label="Prompt (Instruction)",
                    placeholder="请输入指令或问题...",
                    lines=5
                )
                img = gr.Image(type="pil", label="(可选) 上传图片(支持多模态)")
                btn = gr.Button("生成 (Generate)", variant="primary")

            # Right column: sampling controls and the status line.
            with gr.Column(scale=2):
                max_tokens = gr.Slider(
                    label="Max New Tokens",
                    minimum=16,
                    maximum=1024,
                    step=1,
                    value=128
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=1.5,
                    step=0.01,
                    value=0.7
                )
                top_k = gr.Slider(
                    label="Top-k",
                    minimum=0,
                    maximum=200,
                    step=1,
                    value=40
                )
                top_p = gr.Slider(
                    label="Top-p",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.9
                )
                rep_pen = gr.Slider(
                    label="Repetition Penalty",
                    minimum=0.5,
                    maximum=2.0,
                    step=0.01,
                    value=1.1
                )
                status = gr.Textbox(
                    label="Status",
                    value="Ready",
                    interactive=False
                )

        output = gr.Textbox(label="Output", lines=12, interactive=False)
        gr.Examples(
            examples=[
                ["请解释什么是深度学习", None],
                ["计算 123 + 456 等于多少?", None],
                ["写一首关于春天的诗", None],
            ],
            inputs=[txt, img],
        )

        def gr_generate(prompt, image, max_tokens_v, temp_v, topk_v, topp_v, rep_v):
            """Run one generation and return ``(output_text, status_text)``."""
            if not prompt or str(prompt).strip() == "":
                return "", " 请输入 Prompt"

            try:
                # Slider values are cast explicitly before reaching the model.
                out = model_instance.generate_text(
                    prompt=prompt,
                    max_new_tokens=int(max_tokens_v),
                    temperature=float(temp_v),
                    top_k=int(topk_v),
                    top_p=float(topp_v),
                    repetition_penalty=float(rep_v),
                    image=image
                )
                return out, " Done"
            except Exception as e:
                # UI boundary: surface the failure in the output box instead of crashing.
                return f"Error: {str(e)}", " Error"

        btn.click(
            fn=gr_generate,
            inputs=[txt, img, max_tokens, temperature, top_k, top_p, rep_pen],
            outputs=[output, status]
        )

    return demo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
 
291
  def main():
292
+ parser = argparse.ArgumentParser(
293
+ description="Gradio inference interface for MultiModal Dense Transformer"
294
+ )
295
+ parser.add_argument(
296
+ "--checkpoint",
297
+ type=str,
298
+ default="/root/checkpoints/dcpo_posttrain_round3/step_15600.pt",
299
+ help="Path to model checkpoint"
300
+ )
301
+ parser.add_argument(
302
+ "--tokenizer",
303
+ type=str,
304
+ default="Qwen/Qwen2.5-7B-Instruct",
305
+ help="Tokenizer name or path"
306
+ )
307
+ parser.add_argument(
308
+ "--config",
309
+ type=str,
310
+ default=None,
311
+ help="Path to model config JSON (optional)"
312
+ )
313
+ parser.add_argument(
314
+ "--port",
315
+ type=int,
316
+ default=5001,
317
+ help="Port to run Gradio server"
318
+ )
319
+ parser.add_argument(
320
+ "--share",
321
+ type=lambda x: x.lower() in ("true","1","yes"),
322
+ default=True,
323
+ help="Create public link (True/False)"
324
+ )
325
  args = parser.parse_args()
326
+
327
  if not Path(args.checkpoint).exists():
328
+ print(f" Checkpoint not found: {args.checkpoint}")
329
+
330
+ possible_dirs = [
331
+ Path("/root/checkpoints/posttrain/grpo"),
332
+ Path("/root/checkpoints/dcpo_training"),
333
+ Path("/root/checkpoints/r1_zero_reproduction"),
334
+ ]
335
+
336
+ for checkpoint_dir in possible_dirs:
337
+ if checkpoint_dir.exists():
338
+ grpo_files = sorted(
339
+ [p for p in checkpoint_dir.glob("grpo_iter_*.pt")],
340
+ key=lambda p: int(p.stem.split('_')[-1]) if p.stem.split('_')[-1].isdigit() else 0
341
+ )
342
+
343
+ step_files = sorted(
344
+ [p for p in checkpoint_dir.glob("step_*.pt")],
345
+ key=lambda p: int(p.stem.split('_')[-1]) if p.stem.split('_')[-1].isdigit() else 0
346
+ )
347
+
348
+ candidates = grpo_files + step_files
349
+ if candidates:
350
+ args.checkpoint = str(candidates[-1])
351
+ print(f" Using latest checkpoint: {args.checkpoint}")
352
+ break
353
+
354
+ if not Path(args.checkpoint).exists():
355
+ raise FileNotFoundError(f"找不到可用的检查点文件")
356
 
 
357
 
358
  global model_instance
359
+ model_instance = ModelInference(
360
+ args.checkpoint,
361
+ args.tokenizer,
362
+ args.config
363
+ )
364
+
365
 
366
+ demo = build_ui(model_instance)
367
+ demo.launch(
368
+ server_port=args.port,
369
+ share=args.share,
370
+ server_name="0.0.0.0" # 允许外部访问
371
+ )
372
 
373
  if __name__ == "__main__":
374
  main()