gudo7208 committed on
Commit
14230a6
·
verified ·
1 Parent(s): e9a3c74

Delete file batch_inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. batch_inference.py +0 -147
batch_inference.py DELETED
@@ -1,147 +0,0 @@
1
- """CAD-Coder Batch Inference Script using vLLM"""
2
- import argparse
3
- import json
4
- import os
5
- import re
6
- import csv
7
- from tqdm import tqdm
8
- from vllm import LLM, SamplingParams
9
-
10
-
11
def parse_args():
    """Return parsed command-line arguments for CAD-Coder batch inference."""
    parser = argparse.ArgumentParser(description='CAD-Coder Batch Inference')
    # Each entry is (flag, add_argument keyword args); registering them from a
    # table keeps the option set easy to scan and extend.
    option_table = [
        ('--model_path', dict(type=str, default="gudo7208/CAD-Coder",
                              help='Model path or HuggingFace Hub name')),
        ('--data_path', dict(type=str, required=True,
                             help='Test data path (JSON or JSONL format)')),
        ('--output_dir', dict(type=str, default="./output",
                              help='Output directory')),
        ('--tensor_parallel_size', dict(type=int, default=1,
                                        help='Number of GPUs for tensor parallelism')),
        ('--temperature', dict(type=float, default=0.7,
                               help='Sampling temperature')),
        ('--top_p', dict(type=float, default=0.7,
                         help='Top-p sampling')),
        ('--max_tokens', dict(type=int, default=4096,
                              help='Maximum tokens to generate')),
    ]
    for flag, kwargs in option_table:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
28
-
29
-
30
def load_data(data_path):
    """Load test samples from a .json or .jsonl file.

    Args:
        data_path: Path to the input file; the extension selects the parser.

    Returns:
        The decoded data: for .jsonl, a list with one object per non-blank
        line; for .json, whatever the top-level JSON value is.

    Raises:
        ValueError: If the file extension is neither .json nor .jsonl.
    """
    if data_path.endswith('.json'):
        with open(data_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    elif data_path.endswith('.jsonl'):
        data = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                # Skip blank lines (e.g. a trailing newline at EOF) instead
                # of crashing with a JSONDecodeError on json.loads('').
                if line:
                    data.append(json.loads(line))
        return data
    else:
        raise ValueError(f"Unsupported file format: {data_path}")
43
-
44
-
45
def build_prompt(user_content):
    """Wrap *user_content* in the Qwen ChatML chat template.

    Produces a system turn, the user turn, and an opened (unterminated)
    assistant turn so the model continues from there.
    """
    segments = [
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
        f"<|im_start|>user\n{user_content}<|im_end|>\n",
        "<|im_start|>assistant\n",
    ]
    return "".join(segments)
50
-
51
-
52
def extract_code(response):
    """Pull Python source out of a model response.

    Tries the \\boxed{```python ...```} wrapper first, then a bare
    ```python fenced block; returns the stripped code of the first match,
    or the raw response when neither pattern is present.
    """
    # Ordered from most specific to least specific wrapper.
    patterns = (
        r'\\boxed\{```python\n(.*?)```\}',
        r'```python\n(.*?)```',
    )
    for pattern in patterns:
        match = re.search(pattern, response, re.DOTALL)
        if match:
            return match.group(1).strip()
    # No code block found — hand back the response unchanged.
    return response
66
-
67
-
68
def main():
    """Run the batch inference pipeline end to end.

    Loads the model with vLLM, reads the test set, generates one completion
    per sample, and writes per-sample .py/.txt files plus a results.csv
    summary into the output directory.
    """
    args = parse_args()

    # Initialize vLLM model (loads weights onto the GPU(s); fp16, 90% of
    # GPU memory reserved for the engine).
    print(f"Loading model from {args.model_path}...")
    llm = LLM(
        model=args.model_path,
        tensor_parallel_size=args.tensor_parallel_size,
        trust_remote_code=True,
        dtype="float16",
        gpu_memory_utilization=0.9
    )

    # Set sampling parameters shared by every prompt in the batch.
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens
    )

    # Load test data
    print(f"Loading data from {args.data_path}...")
    test_data = load_data(args.data_path)
    # test_data = test_data[:50]  # Take only 50 samples when testing; comment this line out for full runs
    print(f"Loaded {len(test_data)} samples")

    # Prepare prompts and file names (kept in parallel lists so outputs can
    # be matched back to their sample by position).
    prompts = []
    file_names = []

    for item in test_data:
        # Derive the output file stem from the sample's model_path; fall
        # back to a positional name when the key is absent.
        model_path = item.get('model_path', f"sample_{len(file_names)}")
        file_name = os.path.basename(model_path).replace('.pth', '')
        file_names.append(file_name)

        # Extract user content and build the chat-format prompt.
        # NOTE(review): assumes messages[0] is the user turn — confirm
        # against the dataset schema.
        user_content = item['messages'][0]['content']
        prompts.append(build_prompt(user_content))

    # Generate with vLLM (one batched call for all prompts).
    print("Generating...")
    outputs = llm.generate(prompts, sampling_params)

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Prepare CSV data (one [id, response, code] row per sample).
    csv_data = []

    # Save results: for each sample, the extracted code as <name>.py and
    # the full model response as <name>.txt.
    for file_name, output in tqdm(zip(file_names, outputs), total=len(outputs), desc="Saving"):
        response = output.outputs[0].text
        python_code = extract_code(response)

        # Save Python code
        code_path = os.path.join(args.output_dir, f"{file_name}.py")
        with open(code_path, "w", encoding="utf-8") as f:
            f.write(python_code)

        # Save full response
        txt_path = os.path.join(args.output_dir, f"{file_name}.txt")
        with open(txt_path, "w", encoding="utf-8") as f:
            f.write(response)

        # Add to CSV data
        csv_data.append([file_name, response, python_code])

    # Save CSV summary
    csv_path = os.path.join(args.output_dir, "results.csv")
    with open(csv_path, "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "responses", "code"])
        writer.writerows(csv_data)

    print(f"Results saved to {args.output_dir}")
144
-
145
-
146
# Script entry point: run the batch inference pipeline when executed directly.
if __name__ == "__main__":
    main()