Delete batch_inference.py with huggingface_hub

batch_inference.py  DELETED  (+0 -147)
@@ -1,147 +0,0 @@
-"""CAD-Coder Batch Inference Script using vLLM"""
-import argparse
-import json
-import os
-import re
-import csv
-from tqdm import tqdm
-from vllm import LLM, SamplingParams
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(description='CAD-Coder Batch Inference')
-    parser.add_argument('--model_path', type=str, default="gudo7208/CAD-Coder",
-                        help='Model path or HuggingFace Hub name')
-    parser.add_argument('--data_path', type=str, required=True,
-                        help='Test data path (JSON or JSONL format)')
-    parser.add_argument('--output_dir', type=str, default="./output",
-                        help='Output directory')
-    parser.add_argument('--tensor_parallel_size', type=int, default=1,
-                        help='Number of GPUs for tensor parallelism')
-    parser.add_argument('--temperature', type=float, default=0.7,
-                        help='Sampling temperature')
-    parser.add_argument('--top_p', type=float, default=0.7,
-                        help='Top-p sampling')
-    parser.add_argument('--max_tokens', type=int, default=4096,
-                        help='Maximum tokens to generate')
-    return parser.parse_args()
-
-
-def load_data(data_path):
-    """Load test data from JSON or JSONL format"""
-    if data_path.endswith('.json'):
-        with open(data_path, 'r', encoding='utf-8') as f:
-            return json.load(f)
-    elif data_path.endswith('.jsonl'):
-        data = []
-        with open(data_path, 'r', encoding='utf-8') as f:
-            for line in f:
-                data.append(json.loads(line))
-        return data
-    else:
-        raise ValueError(f"Unsupported file format: {data_path}")
-
-
-def build_prompt(user_content):
-    """Build chat format prompt for Qwen model"""
-    return (f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
-            f"<|im_start|>user\n{user_content}<|im_end|>\n"
-            f"<|im_start|>assistant\n")
-
-
-def extract_code(response):
-    """Extract Python code from response, supporting multiple formats"""
-    # Try to extract from \boxed{```python ... ```}
-    boxed_match = re.search(r'\\boxed\{```python\n(.*?)```\}', response, re.DOTALL)
-    if boxed_match:
-        return boxed_match.group(1).strip()
-
-    # Try to extract from ```python ... ```
-    code_match = re.search(r'```python\n(.*?)```', response, re.DOTALL)
-    if code_match:
-        return code_match.group(1).strip()
-
-    # Return raw response if no code block found
-    return response
-
-
-def main():
-    args = parse_args()
-
-    # Initialize vLLM model
-    print(f"Loading model from {args.model_path}...")
-    llm = LLM(
-        model=args.model_path,
-        tensor_parallel_size=args.tensor_parallel_size,
-        trust_remote_code=True,
-        dtype="float16",
-        gpu_memory_utilization=0.9
-    )
-
-    # Set sampling parameters
-    sampling_params = SamplingParams(
-        temperature=args.temperature,
-        top_p=args.top_p,
-        max_tokens=args.max_tokens
-    )
-
-    # Load test data
-    print(f"Loading data from {args.data_path}...")
-    test_data = load_data(args.data_path)
-    # test_data = test_data[:50]  # For a quick test, keep only 50 samples; comment this line out for full runs
print(f"Loaded {len(test_data)} samples")
|
| 93 |
-
|
| 94 |
-
# Prepare prompts and file names
|
| 95 |
-
prompts = []
|
| 96 |
-
file_names = []
|
| 97 |
-
|
| 98 |
-
for item in test_data:
|
| 99 |
-
# Extract file name from model_path
|
| 100 |
-
model_path = item.get('model_path', f"sample_{len(file_names)}")
|
| 101 |
-
file_name = os.path.basename(model_path).replace('.pth', '')
|
| 102 |
-
file_names.append(file_name)
|
| 103 |
-
|
| 104 |
-
# Extract user content and build prompt
|
| 105 |
-
user_content = item['messages'][0]['content']
|
| 106 |
-
prompts.append(build_prompt(user_content))
|
| 107 |
-
|
| 108 |
-
# Generate with vLLM
|
| 109 |
-
print("Generating...")
|
| 110 |
-
outputs = llm.generate(prompts, sampling_params)
|
| 111 |
-
|
| 112 |
-
# Create output directory
|
| 113 |
-
os.makedirs(args.output_dir, exist_ok=True)
|
| 114 |
-
|
| 115 |
-
# Prepare CSV data
|
| 116 |
-
csv_data = []
|
| 117 |
-
|
| 118 |
-
# Save results
|
| 119 |
-
for file_name, output in tqdm(zip(file_names, outputs), total=len(outputs), desc="Saving"):
|
| 120 |
-
response = output.outputs[0].text
|
| 121 |
-
python_code = extract_code(response)
|
| 122 |
-
|
| 123 |
-
# Save Python code
|
| 124 |
-
code_path = os.path.join(args.output_dir, f"{file_name}.py")
|
| 125 |
-
with open(code_path, "w", encoding="utf-8") as f:
|
| 126 |
-
f.write(python_code)
|
| 127 |
-
|
| 128 |
-
# Save full response
|
| 129 |
-
txt_path = os.path.join(args.output_dir, f"{file_name}.txt")
|
| 130 |
-
with open(txt_path, "w", encoding="utf-8") as f:
|
| 131 |
-
f.write(response)
|
| 132 |
-
|
| 133 |
-
# Add to CSV data
|
| 134 |
-
csv_data.append([file_name, response, python_code])
|
| 135 |
-
|
| 136 |
-
# Save CSV summary
|
| 137 |
-
csv_path = os.path.join(args.output_dir, "results.csv")
|
| 138 |
-
with open(csv_path, "w", encoding="utf-8", newline="") as f:
|
| 139 |
-
writer = csv.writer(f)
|
| 140 |
-
writer.writerow(["id", "responses", "code"])
|
| 141 |
-
writer.writerows(csv_data)
|
| 142 |
-
|
| 143 |
-
print(f"Results saved to {args.output_dir}")
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
if __name__ == "__main__":
|
| 147 |
-
main()
|
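For the record, the deleted script was self-contained: given a JSON or JSONL test set, it loaded the checkpoint with vLLM and wrote one .py and one .txt file per sample plus a results.csv summary. A typical invocation (the data path and output directory here are placeholders, not taken from the repo) would have looked like:

    python batch_inference.py \
        --model_path gudo7208/CAD-Coder \
        --data_path ./data/test.jsonl \
        --output_dir ./output \
        --tensor_parallel_size 1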
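load_data accepted either a JSON array or JSONL. Each record needs a messages list whose first entry carries the user prompt; model_path is optional and is only used, with its .pth suffix stripped, to name the per-sample output files. A hypothetical JSONL record consistent with the code above, with illustrative values only:

    {"model_path": "cad_data/00001234.pth",
     "messages": [{"role": "user", "content": "Write Python code that builds the CAD model described by ..."}]}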
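Note that build_prompt hard-codes Qwen's ChatML markers. A more portable variant, sketched below on the assumption that the checkpoint bundles a chat template (standard for Qwen-family models), delegates the formatting to the tokenizer; the system message matches the one in the script:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gudo7208/CAD-Coder", trust_remote_code=True)

    def build_prompt(user_content):
        """Render the same system/user turns via the tokenizer's chat template."""
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": user_content},
        ]
        # add_generation_prompt=True appends the assistant header, mirroring
        # the trailing "<|im_start|>assistant\n" in the original build_prompt
        return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)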