| import asyncio |
| from openai import AsyncOpenAI |
| from tqdm import tqdm |
| from load_dataset import load_dataset, length_max |
| from itertools import islice |
| import csv |
|
|
# Shared async client used for every request in this script. It talks to the
# server at localhost:8000 (presumably a locally hosted OpenAI-compatible
# endpoint -- confirm against the deployment); the API key is a placeholder.
client = AsyncOpenAI(api_key="none", base_url="http://localhost:8000/v1")
|
|
| |
# Fields of a single function record inside the schema below. Key order
# matters: it is reused verbatim as the "required" list.
_function_entry_properties = {
    "function_name": {
        "type": "string",
        "description": "The function name.",
    },
    "function_start_line": {
        "type": "integer",
        "description": "The starting line number of the function definition (inclusive).",
    },
    "function_end_line": {
        "type": "integer",
        "description": "The ending line number of the function definition (inclusive).",
    },
    "relevance_score": {
        "type": "integer",
        "minimum": 0,
        "maximum": 100,
        "description": "Relevance score (0–100) for scientific/chemistry-related computing. Only include functions with score > 0.",
    },
    "relevance_reason": {
        "type": "string",
        "description": "Explanation of why this function is related to scientific/chemical computing and why it received that score.",
    },
    "doc_start_line": {
        "type": ["integer", "null"],
        "description": "Starting line number of the associated documentation comment, or null if none.",
    },
    "doc_end_line": {
        "type": ["integer", "null"],
        "description": "Ending line number of the associated documentation comment, or null if none.",
    },
}

# JSON Schema for the expected model output: an array of function records,
# each with no properties beyond the ones declared above.
scientific_func_schema = {
    "type": "array",
    "description": "List of functions related to scientific and especially chemistry-related computing.",
    "items": {
        "type": "object",
        "additionalProperties": False,
        "properties": _function_entry_properties,
        # Every declared property is required, in declaration order.
        "required": list(_function_entry_properties),
    },
}
|
|
|
|
async def process_one(code_file):
    """Send a single prompt to the chat-completions endpoint.

    Args:
        code_file: A ``(prompt, row)`` pair. ``prompt`` is sent as the only
            user message; ``row`` is passed through unchanged so the caller
            can associate the reply with its input.

    Returns:
        A ``(row, content)`` tuple, where ``content`` is the message content
        of the first choice in the response.

    Raises:
        Any exception raised by the client call propagates to the caller;
        nothing is caught here.
    """
    prompt, row = code_file
    resp = await client.chat.completions.create(
        model="Qwen3",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=8192,
        temperature=0.7,
        top_p=0.8,
        presence_penalty=1.5,
        frequency_penalty=1.5,
        # Extra fields merged into the request body as-is; they are not
        # named parameters of the client's create() call.
        extra_body={
            "top_k": 20,
            "chat_template_kwargs": {
                "enable_thinking": False,
            },
        },
    )
    return row, resp.choices[0].message.content
|
|
# Module-level CSV sink shared by the batch workers. It is opened in append
# mode so rows are added to whatever 'res2.csv' already contains.
# newline='' hands line-ending control to the csv module: without it, the
# text layer translates newlines, which corrupts the '\r\n' row terminator
# on some platforms and can alter newlines embedded in quoted fields.
res_file = open('res2.csv', 'a+', encoding='utf-8', newline='')
writer = csv.writer(res_file)
|
|
async def process_batch(batch):
    """Run every item of ``batch`` concurrently, showing a per-batch progress bar.

    Each item is handed to ``process_one`` as its own task. As tasks finish,
    the first two elements of each result are appended as one row to the
    module-level CSV ``writer``.

    Args:
        batch: Iterable of items accepted by ``process_one``.

    Returns:
        List of the ``process_one`` results, in the order they completed.
    """
    pending = [asyncio.create_task(process_one(item)) for item in batch]
    finished = []

    # as_completed yields awaitables in completion order, so the bar advances
    # as replies arrive rather than in submission order.
    progress = tqdm(
        asyncio.as_completed(pending),
        total=len(pending),
        desc="Processing batch",
        unit="req",
        leave=False,
    )
    for next_done in progress:
        outcome = await next_done
        row_key, reply = outcome[0], outcome[1]
        writer.writerow([row_key, reply])
        finished.append(outcome)

    return finished
|
|
async def process_dataset(dataset_iter, batch_size=200):
    """Process the dataset in consecutive batches with an overall progress bar.

    Args:
        dataset_iter: Iterable of items accepted by ``process_batch``. It is
            consumed ``batch_size`` items at a time.
        batch_size: Number of items per batch; must be at least 1.

    Returns:
        List of every result returned by ``process_batch``, batch after batch.

    Raises:
        ValueError: If ``batch_size`` is less than 1.

    Side effects:
        Prints the total number of results and writes that number to
        ``res.log`` (overwriting the file).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # iter() is a no-op for an iterator, but makes a re-iterable input
    # (e.g. a list) advance instead of replaying its first batch every time.
    items = iter(dataset_iter)

    results = []
    # length_max (imported from load_dataset) bounds how many items are taken.
    num_batches = (length_max + batch_size - 1) // batch_size
    for _ in tqdm(range(num_batches), desc="Overall progress", unit="batch"):
        batch = list(islice(items, batch_size))
        if not batch:
            # Input ran out before length_max items; nothing left to process.
            break
        # Collect the batch output so the returned list really holds all
        # results (previously it stayed empty and the caller reported 0).
        results.extend(await process_batch(batch))

    amount = len(results)
    print("处理完成,共获得结果条数:", amount)
    with open("res.log", "w", encoding="utf-8") as f:
        f.write(str(amount))
    return results
|
|
if __name__ == "__main__":
    # Script entry point: load the dataset, process it in batches of 64 and
    # report how many results came back. The try/finally ensures the shared
    # CSV output file is closed (and its buffer flushed) even if the run
    # fails part-way through.
    try:
        dataset_iter = load_dataset()
        final_results = asyncio.run(process_dataset(dataset_iter, batch_size=64))
        print("处理完成,共获得结果条数:", len(final_results))
    finally:
        res_file.close()