Spaces:
Runtime error
Runtime error
| import json | |
| import random | |
| from collections import defaultdict, deque | |
def _is_all_steps_connected(steps):
    """Return True if every step feeds, directly or transitively, into the last step.

    Integer entries in a step's ``inputs`` are treated as references to earlier
    step ids; every other entry (location names, directions, distances) is
    ignored. The dependency graph is walked backwards starting at the last step.
    An empty step list is reported as not connected.
    """
    if not steps:
        return False
    # Build the reverse dependency graph: step id -> ids of steps it consumes.
    parents = defaultdict(list)
    all_ids = set()
    for step in steps:
        all_ids.add(step["id"])
        parents[step["id"]].extend(inp for inp in step["inputs"] if isinstance(inp, int))

    # BFS from the final step; mark nodes when enqueued so each is visited once.
    start = steps[-1]["id"]
    visited = {start}
    queue = deque([start])
    while queue:
        for parent in parents[queue.popleft()]:
            if parent not in visited:
                visited.add(parent)
                queue.append(parent)
    return all_ids <= visited


def _build_candidate_steps(rng):
    """Draw one random candidate step chain.

    Returns ``(locations, used_locations, step_definitions)`` where
    ``locations`` are the declared ``LOC_k`` names, ``used_locations`` are the
    ones actually referenced by a step, and ``step_definitions`` is the list of
    ``{"id", "function", "inputs"}`` dicts (possibly empty).
    """
    num_locations = rng.randint(1, 3)
    locations = [f"LOC_{i+1}" for i in range(num_locations)]
    used_locations = set()
    all_refs = list(locations)  # step inputs can be LOC names or earlier step ids
    step_definitions = []
    current_id = 1
    for _ in range(rng.randint(2, 5)):
        func = rng.choice(["Relative", "Azimuth", "Between"])
        if func in ("Relative", "Azimuth"):
            base = rng.choice(all_refs)
            if isinstance(base, str):
                used_locations.add(base)
            if func == "Relative":
                modifier = rng.choice([
                    "north", "south", "east", "west",
                    "northeast", "northwest", "southeast", "southwest"
                ])
            else:
                modifier = f"{rng.randint(0, 359)}°"
            distance = f"{rng.randint(1, 10)} km"
            inputs = [base, modifier, distance]
        elif len(all_refs) >= 2:
            base1, base2 = rng.sample(all_refs, 2)
            used_locations.update(b for b in (base1, base2) if isinstance(b, str))
            inputs = [base1, base2]
        else:
            # "Between" needs two distinct references; skip this draw.
            continue
        step_definitions.append({
            "id": current_id,
            "function": func,
            "inputs": inputs
        })
        all_refs.append(current_id)
        current_id += 1
    return locations, used_locations, step_definitions


def generate_localization_samples(n, rng=None):
    """Generate ``n`` random localization samples by rejection sampling.

    Each sample is ``{"index": i, "instruction": "", "steps": [...]}`` with
    ``index`` running from 1 to n. Every step is a dict with keys ``id``,
    ``function`` (``"Relative"``, ``"Azimuth"`` or ``"Between"``) and
    ``inputs``. A candidate chain is kept only if (a) every declared location
    ``LOC_1..LOC_k`` is referenced by some step and (b) every step contributes,
    directly or via later steps, to the final step. Otherwise it is discarded
    and a new chain is drawn.

    Args:
        n: Number of samples to return; values <= 0 yield an empty list.
        rng: Source of randomness providing ``randint``, ``choice`` and
            ``sample`` (e.g. ``random.Random(seed)`` for reproducible output).
            Defaults to the global ``random`` module, as before.

    Returns:
        A list of sample dicts.
    """
    rng = random if rng is None else rng
    all_data = []
    while len(all_data) < n:
        locations, used_locations, step_definitions = _build_candidate_steps(rng)
        if not step_definitions:
            continue  # no valid steps were drawn; regenerate
        if used_locations.issuperset(locations) and _is_all_steps_connected(step_definitions):
            all_data.append({
                "index": len(all_data) + 1,
                "instruction": "",
                "steps": step_definitions
            })
        # otherwise: reject the candidate and regenerate
    return all_data
def write_custom_json(data, filename):
    """Write samples to *filename* as a JSON array with one step per line.

    Each item must provide ``index`` and ``steps``; ``instruction`` is written
    when present (defaulting to an empty string). Each step must provide ``id``,
    ``function`` and ``inputs``. String values are JSON-escaped and non-ASCII
    characters (e.g. ``°``) are written as-is in UTF-8.

    The whole document is rendered before the file is opened, so a malformed
    item raises without leaving a truncated file behind.
    """
    def format_step(step):
        # Render a step on a single line, escaping every value via json.dumps.
        step_id = json.dumps(step["id"])
        function = json.dumps(step["function"], ensure_ascii=False)
        inputs = json.dumps(step["inputs"], ensure_ascii=False)
        return f'{{"id": {step_id}, "function": {function}, "inputs": {inputs}}}'

    blocks = []
    for item in data:
        step_lines = ",\n".join(f" {format_step(step)}" for step in item["steps"])
        instruction = json.dumps(item.get("instruction", ""), ensure_ascii=False)
        blocks.append(
            " {\n"
            f' "index": {item["index"]},\n'
            f' "instruction": {instruction},\n'
            ' "steps": [\n'
            f"{step_lines}\n ]\n"
            " }"
        )

    # Items are comma-separated; the last one is followed by a bare newline.
    body = ",\n".join(blocks) + ("\n" if blocks else "")
    with open(filename, "w", encoding="utf-8") as f:
        f.write("[\n" + body + "]\n")
# Script entry point: build 100 samples and save them next to the script.
if __name__ == "__main__":
    output_path = "localization_samples.json"
    generated = generate_localization_samples(100)
    write_custom_json(generated, output_path)
    print("✅ Saved to localization_samples.json with all steps contributing.")