# GraphGeneratorKIT / graphGen3.py
# Uploaded by TahaRasouli ("Upload 4 files", commit 1c6109f, verified).
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import random
import time
class NetworkGenerator:
    """Generate grid-embedded networks that model rooms in a building.

    Nodes are ``(row, col)`` integer tuples on a ``(grid + 1) x (grid + 1)``
    lattice; ``plot`` draws the first coordinate downward and the second to
    the right. Rooms are linked into one component, topology-specific extra
    edges are added, and crossing edges (as judged by ``_segments_intersect``)
    are then removed.

    Args:
        size: 'S', 'M' or 'L' (case-insensitive); selects the grid size.
        variant: 'F' (fixed node/edge counts) or 'R' (randomised counts).
        topology: 'highly_connected', 'bottlenecks' or 'linear'.

    Raises:
        ValueError: if any argument is not one of the accepted values.
    """

    def __init__(self, size='S', variant='F', topology='highly_connected'):
        self.size = size.upper()
        self.variant = variant.upper()
        self.topology = topology.lower()

        if self.topology not in ['highly_connected', 'bottlenecks', 'linear']:
            raise ValueError("topology must be: 'highly_connected', 'bottlenecks', or 'linear'")

        # Per-size configuration (small, middle, large). 'diag_weights' are the
        # relative weights of step lengths 1 and 2 used by _add_random_neighbors.
        self.size_config = {
            'S': {'grid': 4, 'node_factor': 0.4, 'diag_weights': [1, 4]},
            'M': {'grid': 8, 'node_factor': 0.4, 'diag_weights': [1, 4]},
            'L': {'grid': 16, 'node_factor': 0.4, 'diag_weights': [1, 8]},
        }

        if self.size not in self.size_config:
            raise ValueError("Invalid size. Choose 'S', 'M', or 'L'.")
        if self.variant not in ['F', 'R']:
            raise ValueError("Invalid variant. Choose 'F' (fixed) or 'R' (random).")

        # Scenario setup
        config = self.size_config[self.size]
        self.grid_size = config['grid']
        self.node_factor = config['node_factor']
        self.weight_dist = config['diag_weights']

        # Graph and node storage (populated by generate()).
        self.graph = None
        self.nodes_list = None

    # --------------------
    # Top-level generation
    # --------------------

    def generate(self, max_attempts=5):
        """Generate a connected network representing rooms in a building.

        Each attempt places the rooms, links them into a single component via
        nearby rooms, adds topology-specific extra edges and removes crossing
        edges. The first attempt whose result is connected is returned.

        Args:
            max_attempts: number of attempts before giving up (default 5).

        Returns:
            The generated ``networkx.Graph`` (also stored on ``self.graph``).

        Raises:
            RuntimeError: if no attempt produced a connected graph.
        """
        for _ in range(max_attempts):
            self._initialize_graph()
            self._add_nodes()

            nodes = list(self.graph.nodes())
            if not nodes:
                continue

            # STEP 1: connectivity through nearby rooms.
            self._connect_nearby_rooms(nodes)
            # STEP 2: topology-specific extra edges.
            self._add_edges()
            # STEP 3: remove intersections and reconnect.
            self._remove_intersections()
            # STEP 4: final connectivity check.
            if nx.is_connected(self.graph):
                return self.graph

        raise RuntimeError("Failed to generate a connected network after several attempts")

    @staticmethod
    def _is_near(a, b, radius=2):
        """Return True if a and b differ by at most ``radius`` in both coordinates."""
        return abs(a[0] - b[0]) <= radius and abs(a[1] - b[1]) <= radius

    def _connect_nearby_rooms(self, nodes):
        """Join all ``nodes`` into one component by adding one edge per node.

        Starting from a random node, repeatedly pick an unconnected node within
        Chebyshev distance 2 of the connected set (or any unconnected node if
        none is that close). Join it to a random connected node within distance
        2 (or any connected node). Edges are added unconditionally, even when
        they are neither straight nor diagonal or they cross existing edges;
        crossings are cleaned up later by _remove_intersections.
        """
        start = random.choice(nodes)
        connected = {start}
        remaining = set(nodes)
        remaining.remove(start)

        while remaining:
            candidates = [n for n in remaining if any(self._is_near(n, c) for c in connected)]
            if candidates:
                candidate = random.choice(candidates)
            else:
                candidate = random.choice(list(remaining))

            neighbors = [c for c in connected if self._is_near(c, candidate)]
            if neighbors:
                anchor = random.choice(neighbors)
            else:
                anchor = random.choice(list(connected))

            self.graph.add_edge(anchor, candidate)
            connected.add(candidate)
            remaining.remove(candidate)

    # --------------------
    # Node placement
    # --------------------

    def _middle_region(self):
        """Return the inclusive ``(low, high)`` coordinate bounds of the central region."""
        margin = max(1, self.grid_size // 4)
        return margin, self.grid_size - margin

    def _initialize_graph(self):
        """Reset the graph to a single random node inside the middle region.

        Also resets ``nodes_list`` to that node's ``[coords, flags]`` entry.
        """
        self.graph = nx.Graph()
        # Start in the middle region instead of (0, 0).
        low, high = self._middle_region()
        x = random.randint(low, high)
        y = random.randint(low, high)
        coords = np.array([x, y])
        flags = np.zeros(4, dtype=int)
        self.nodes_list = [[coords, flags]]
        # Key the node with plain ints (not the numpy scalars in `coords`) so it
        # matches the tuples created elsewhere and prints cleanly as a label.
        self.graph.add_node((x, y))

    def _compute_nodes(self):
        """Return the requested node count.

        The count is ``node_factor`` (variant 'F') or a uniform random factor
        in [0.4, 0.7] (variant 'R'), times the ``(grid_size + 1) ** 2`` lattice
        points.
        """
        total_possible = (self.grid_size + 1) ** 2
        if self.variant == 'F':
            return int(self.node_factor * total_possible)
        return int(random.uniform(0.4, 0.7) * total_possible)

    def _add_nodes(self):
        """Place nodes on random free cells of the middle region (cluster logic).

        The target is _compute_nodes(), capped at the number of cells in the
        middle region. At most 5000 random draws are made.
        """
        low, high = self._middle_region()
        capacity = (high - low + 1) ** 2
        # For every configured size the requested count exceeds the middle
        # region's capacity (e.g. 'S'/'F': 10 requested, 9 cells). Without the
        # cap, the loop would always exhaust all 5000 draws.
        total_nodes = min(self._compute_nodes(), capacity)

        attempts = 0
        while len(self.graph) < total_nodes and attempts < 5000:
            attempts += 1
            x = random.randint(low, high)
            y = random.randint(low, high)
            if (x, y) not in self.graph:
                self.graph.add_node((x, y))

    def _add_random_neighbors(self):
        """Expand from the node at the head of ``nodes_list``, then dequeue it.

        Makes 1-4 attempts to add an axis-aligned neighbour at distance 1 or 2
        (weighted by ``weight_dist``). Each new node is appended to
        ``nodes_list`` and the direction flags are updated. Not used by
        generate().
        """
        if not self.nodes_list:
            return

        coords = self.nodes_list[0][0]
        for _ in range(random.randint(1, 4)):
            direction = random.choice(['V', 'H'])
            distance = random.choices([1, 2], weights=self.weight_dist, k=1)[0]
            new_coords = self._get_new_node(coords, direction, distance)
            if new_coords is None:
                continue
            key = (int(new_coords[0]), int(new_coords[1]))
            if key in self.graph:
                continue
            self.graph.add_node(key)
            self.nodes_list.append([new_coords, np.zeros(4, dtype=int)])
            self._update_neighbor_flags(coords, new_coords)

        self.nodes_list.pop(0)

    def _get_new_node(self, coords, direction, dist):
        """Return coordinates ``dist`` steps from ``coords`` along one axis, or None.

        'V' moves along the first coordinate and 'H' along the second. With
        probability 1/2 the positive step is tried first. Otherwise, or when the
        positive step leaves the grid, the negative step is used. Returns None
        when the negative step is out of bounds, even if the untried positive
        step would have fit.
        """
        x, y = coords
        if direction == 'V':
            if random.choice([True, False]) and x + dist <= self.grid_size:
                return np.array([x + dist, y])
            elif x - dist >= 0:
                return np.array([x - dist, y])
        elif direction == 'H':
            if random.choice([True, False]) and y + dist <= self.grid_size:
                return np.array([x, y + dist])
            elif y - dist >= 0:
                return np.array([x, y - dist])
        return None

    def _update_neighbor_flags(self, predecessor_coords, new_coords):
        """Set the direction flags linking two entries of ``nodes_list``.

        Flags are ``[up, down, left, right]``: "up" means a smaller first
        coordinate and "left" a smaller second coordinate. Does nothing if
        either coordinate pair is missing from ``nodes_list``.
        """
        px, py = predecessor_coords
        nx_, ny = new_coords

        predecessor_idx = next(
            (i for i, n in enumerate(self.nodes_list) if np.array_equal(n[0], predecessor_coords)),
            None,
        )
        new_node_idx = next(
            (i for i, n in enumerate(self.nodes_list) if np.array_equal(n[0], new_coords)),
            None,
        )
        if predecessor_idx is None or new_node_idx is None:
            return

        pred_flags = self.nodes_list[predecessor_idx][1]
        new_flags = self.nodes_list[new_node_idx][1]
        if nx_ < px:  # new node above
            pred_flags[0] = 1
            new_flags[1] = 1
        elif nx_ > px:  # new node below
            pred_flags[1] = 1
            new_flags[0] = 1
        elif ny < py:  # new node to the left
            pred_flags[2] = 1
            new_flags[3] = 1
        elif ny > py:  # new node to the right
            pred_flags[3] = 1
            new_flags[2] = 1

    # --------------------
    # Edge construction
    # --------------------

    def _compute_edge_count(self):
        """Return the extra-edge budget.

        The budget is 1.5 x the node count (variant 'F') or a uniform random
        1.5-2.5 x the node count (variant 'R').
        """
        total_nodes = len(self.graph.nodes())
        if self.variant == 'F':
            return int(1.5 * total_nodes)
        return int(random.uniform(1.5, 2.5) * total_nodes)

    def _add_edges(self):
        """Add extra edges according to ``self.topology``."""
        nodes = list(self.graph.nodes())
        total_edges = self._compute_edge_count()

        if self.topology == "highly_connected":
            self._add_cluster_dense(nodes, total_edges)
        elif self.topology == "bottlenecks":
            self._add_cluster_sparse(nodes, total_edges)
            self._add_cluster_bottleneck(nodes)
        elif self.topology == "linear":
            self._make_linear(nodes)

    def _add_straight_edges_if_no_intersection(self, nodes, max_edges):
        """Add up to ``max_edges`` new axis-aligned edges that overlap no existing edge."""
        count = 0
        for i, n1 in enumerate(nodes):
            for n2 in nodes[i + 1:]:
                if count >= max_edges:
                    return
                straight = n1[0] == n2[0] or n1[1] == n2[1]
                if (straight
                        and not self.graph.has_edge(n1, n2)
                        and not self._straight_edge_intersects(n1, n2)):
                    self.graph.add_edge(n1, n2)
                    count += 1

    def _straight_edge_intersects(self, n1, n2):
        """Return True if the axis-aligned edge n1-n2 overlaps an existing edge.

        Only existing edges on the same line are considered. They count only if
        they share a stretch of positive length with n1-n2; touching at a single
        shared node is allowed, consistent with _segments_intersect.
        Perpendicular crossings are not detected here. Non-straight edges
        always return True, so callers reject them.
        """
        (x1, y1), (x2, y2) = n1, n2
        if not (x1 == x2 or y1 == y2):
            return True

        target = {n1, n2}
        for a, b in self.graph.edges():
            if {a, b} == target:
                continue
            (ax, ay), (bx, by) = a, b

            # Same second coordinate: compare spans along the first axis.
            if y1 == y2 and ay == by == y1:
                lo, hi = min(x1, x2), max(x1, x2)
                if max(ax, bx) > lo and min(ax, bx) < hi:
                    return True

            # Same first coordinate: compare spans along the second axis.
            if x1 == x2 and ax == bx == x1:
                lo, hi = min(y1, y2), max(y1, y2)
                if max(ay, by) > lo and min(ay, by) < hi:
                    return True
        return False

    def _diagonal_intersects(self, n1, n2):
        """Conservatively decide whether the diagonal n1-n2 should be rejected.

        Returns True if the bounding box of n1-n2 overlaps or touches the
        bounding box of any existing diagonal edge. This includes the edge
        itself if already present, and edges that merely share an endpoint.
        Straight edges are not considered.
        """
        x1, y1 = n1
        x2, y2 = n2
        for (ax, ay), (bx, by) in self.graph.edges():
            if abs(ax - bx) != abs(ay - by):
                continue  # only diagonal edges are checked
            x_overlap = not (max(x1, x2) < min(ax, bx) or min(x1, x2) > max(ax, bx))
            y_overlap = not (max(y1, y2) < min(ay, by) or min(y1, y2) > max(ay, by))
            if x_overlap and y_overlap:
                return True
        return False

    def _generate_diagonal_edges(self, nodes, max_edges):
        """Add up to ``max_edges`` new 45-degree edges without any intersection check."""
        count = 0
        for i, n1 in enumerate(nodes):
            for n2 in nodes[i + 1:]:
                if count >= max_edges:
                    return
                if (abs(n1[0] - n2[0]) == abs(n1[1] - n2[1])
                        and not self.graph.has_edge(n1, n2)):
                    self.graph.add_edge(n1, n2)
                    count += 1

    def _make_linear(self, nodes):
        """Build a corridor-like backbone plus occasional short side branches.

        Nodes are visited in sorted (first, second coordinate) order.
        Consecutive nodes sharing a coordinate are joined directly. Otherwise
        they are joined through a one-step intermediate node, trying
        (first +/- 1, second) before (first, second +/- 1), when that node
        exists. If neither exists, the pair is skipped and the backbone
        continues from the previous node. Afterwards each node has a 15% chance
        to gain one edge to an adjacent node, provided both ends have degree
        below 3.
        """
        node_set = set(nodes)
        nodes_sorted = sorted(nodes)

        # Backbone: straight links only.
        prev = nodes_sorted[0]
        for nxt in nodes_sorted[1:]:
            (x1, y1), (x2, y2) = prev, nxt
            if x1 == x2 or y1 == y2:
                self.graph.add_edge(prev, nxt)
                prev = nxt
                continue

            dx = 1 if x2 > x1 else -1
            dy = 1 if y2 > y1 else -1
            for step in ((x1 + dx, y1), (x1, y1 + dy)):
                if step in node_set:
                    self.graph.add_edge(prev, step)
                    self.graph.add_edge(step, nxt)
                    prev = nxt
                    break

        # Occasional side branches (15% chance per node).
        for node in nodes_sorted:
            if random.random() >= 0.15:
                continue
            x, y = node
            candidates = [(x + 1, y), (x - 1, y), (x, y + 1), (x, y - 1)]
            random.shuffle(candidates)
            for c in candidates:
                if c in node_set and not self.graph.has_edge(node, c):
                    # Keep the corridor thin: both ends stay below degree 3.
                    if self.graph.degree(node) < 3 and self.graph.degree(c) < 3:
                        self.graph.add_edge(node, c)
                        break

    def _add_sparse_edges(self, nodes):
        """Add each possible edge independently with probability 0.15 (no geometry checks)."""
        for i, n1 in enumerate(nodes):
            for n2 in nodes[i + 1:]:
                if random.random() < 0.15:
                    self.graph.add_edge(n1, n2)

    def _create_bottleneck(self, nodes):
        """Add one edge between a random node in each half (split on the first coordinate)."""
        mid = self.grid_size // 2
        left = [n for n in nodes if n[0] <= mid]
        right = [n for n in nodes if n[0] > mid]
        if not left or not right:
            return
        self.graph.add_edge(random.choice(left), random.choice(right))

    def _add_dense_edges(self, nodes):
        """Add every straight and every 45-degree edge (no intersection checks)."""
        for i, (x1, y1) in enumerate(nodes):
            for n2 in nodes[i + 1:]:
                x2, y2 = n2
                if x1 == x2 or y1 == y2 or abs(x1 - x2) == abs(y1 - y2):
                    self.graph.add_edge(nodes[i], n2)

    def _add_cluster_dense(self, nodes, max_edges):
        """Add up to ``max_edges`` new straight/diagonal edges that pass the overlap checks.

        Shuffles ``nodes`` in place. Pairs that are already connected do not
        consume the budget.
        """
        edges_added = 0
        random.shuffle(nodes)
        for i, n1 in enumerate(nodes):
            for n2 in nodes[i + 1:]:
                if edges_added >= max_edges:
                    return
                if self.graph.has_edge(n1, n2):
                    continue

                if n1[0] == n2[0] or n1[1] == n2[1]:
                    if not self._straight_edge_intersects(n1, n2):
                        self.graph.add_edge(n1, n2)
                        edges_added += 1
                    continue

                if abs(n1[0] - n2[0]) == abs(n1[1] - n2[1]):
                    if not self._diagonal_intersects(n1, n2):
                        self.graph.add_edge(n1, n2)
                        edges_added += 1

    def _add_cluster_sparse(self, nodes, max_edges):
        """Add up to ``max_edges`` new straight edges, each pair considered with probability 0.15.

        Shuffles ``nodes`` in place. Pairs that are already connected do not
        consume the budget.
        """
        edges_added = 0
        random.shuffle(nodes)
        for i, n1 in enumerate(nodes):
            for n2 in nodes[i + 1:]:
                if edges_added >= max_edges:
                    return
                if random.random() >= 0.15:
                    continue
                straight = n1[0] == n2[0] or n1[1] == n2[1]
                if (straight
                        and not self.graph.has_edge(n1, n2)
                        and not self._straight_edge_intersects(n1, n2)):
                    self.graph.add_edge(n1, n2)
                    edges_added += 1

    def _add_cluster_bottleneck(self, nodes):
        """Try to add one edge between random nodes in each half (split on the first coordinate).

        The edge is added only if it is straight and overlaps no existing edge.
        """
        mid = self.grid_size // 2
        left = [n for n in nodes if n[0] <= mid]
        right = [n for n in nodes if n[0] > mid]
        if not left or not right:
            return
        a = random.choice(left)
        b = random.choice(right)
        if not self._straight_edge_intersects(a, b):
            self.graph.add_edge(a, b)

    # --------------------
    # Intersection utilities
    # --------------------

    def _orientation(self, p, q, r):
        """Return orientation for ordered triplet (p, q, r).

        0 = collinear, 1 = clockwise, 2 = counterclockwise.
        """
        (px, py), (qx, qy), (rx, ry) = p, q, r
        val = (qy - py) * (rx - qx) - (qx - px) * (ry - qy)
        if val == 0:
            return 0
        return 1 if val > 0 else 2

    def _on_segment(self, p, q, r):
        """Return True if q lies within the bounding box of segment pr.

        Only meaningful when p, q and r are already known to be collinear.
        """
        (px, py), (qx, qy), (rx, ry) = p, q, r
        return (min(px, rx) <= qx <= max(px, rx) and
                min(py, ry) <= qy <= max(py, ry))

    def _segments_intersect(self, a, b, c, d):
        """Return True if segments ab and cd intersect.

        Segments that only meet at a shared endpoint do not count. Segments
        that share an endpoint and continue in the same direction (a collinear
        overlap of positive length) do count. Identical segments return False.
        """
        if a in (c, d) or b in (c, d):
            if {a, b} == {c, d}:
                return False
            # p: the shared endpoint; q, r: the other ends of ab and cd.
            p = a if a in (c, d) else b
            q = b if p == a else a
            r = d if p == c else c
            if self._orientation(p, q, r) != 0:
                return False
            # Collinear: overlap iff q and r lie on the same side of p.
            return (q[0] - p[0]) * (r[0] - p[0]) + (q[1] - p[1]) * (r[1] - p[1]) > 0

        o1 = self._orientation(a, b, c)
        o2 = self._orientation(a, b, d)
        o3 = self._orientation(c, d, a)
        o4 = self._orientation(c, d, b)

        # General case
        if o1 != o2 and o3 != o4:
            return True

        # Special cases (collinear)
        if o1 == 0 and self._on_segment(a, c, b):
            return True
        if o2 == 0 and self._on_segment(a, d, b):
            return True
        if o3 == 0 and self._on_segment(c, a, d):
            return True
        if o4 == 0 and self._on_segment(c, b, d):
            return True
        return False

    def _would_create_intersection(self, u, v):
        """Return True if adding edge (u, v) would intersect any existing edge.

        Edges touching u or v are checked too; _segments_intersect flags them
        only when they overlap collinearly.
        """
        return any(self._segments_intersect(u, v, x, y) for x, y in self.graph.edges())

    @staticmethod
    def _squared_length(p, q):
        """Return the squared Euclidean length of segment pq."""
        return (p[0] - q[0]) ** 2 + (p[1] - q[1]) ** 2

    def _intersecting_edge_pairs(self):
        """Return every ``(a, b, c, d)`` such that current edges ab and cd intersect."""
        edges = list(self.graph.edges())
        return [
            (a, b, c, d)
            for i, (a, b) in enumerate(edges)
            for c, d in edges[i + 1:]
            if self._segments_intersect(a, b, c, d)
        ]

    def _remove_longer_edge(self, a, b, c, d):
        """Remove the longer of edges ab and cd (ab on ties); return True if one was removed.

        Nothing is removed if either edge is already gone, because the
        intersection between them has then already been resolved.
        """
        if not (self.graph.has_edge(a, b) and self.graph.has_edge(c, d)):
            return False
        if self._squared_length(a, b) >= self._squared_length(c, d):
            self.graph.remove_edge(a, b)
        else:
            self.graph.remove_edge(c, d)
        return True

    def _remove_intersections(self):
        """Remove crossing edges, reconnecting components along the way.

        Up to 10 passes are made. Each pass finds all intersecting edge pairs,
        removes the longer edge of every pair whose edges are both still
        present, and then reconnects components, preferring Chebyshev distance
        <= 2. If the graph is still disconnected, reconnection is retried with
        distance up to ``grid_size``. A final sweep then removes any crossings
        that reconnection introduced; this may disconnect the graph again,
        which generate() checks. Prints a one-line summary.
        """
        max_passes = 10
        pass_no = 0
        total_removed = 0

        while pass_no < max_passes:
            pass_no += 1
            intersections = self._intersecting_edge_pairs()
            if not intersections:
                break  # no intersections left
            for pair in intersections:
                if self._remove_longer_edge(*pair):
                    total_removed += 1
            self._attempt_reconnect_components(prefer_max_distance=2)

        if not nx.is_connected(self.graph):
            self._attempt_reconnect_components(prefer_max_distance=self.grid_size)

        # Final sweep for crossings created during reconnection.
        for pair in self._intersecting_edge_pairs():
            if self._remove_longer_edge(*pair):
                total_removed += 1

        # (You can replace prints with logging if preferred)
        print(f"[cleanup] Removed {total_removed} intersecting edges after {pass_no} passes.")

    @staticmethod
    def _chebyshev(a, b):
        """Return the Chebyshev (max-coordinate) distance between a and b."""
        return max(abs(a[0] - b[0]), abs(a[1] - b[1]))

    def _closest_pair(self, comp_a, comp_b, max_dist):
        """Return the closest ``(u, v)``, u in comp_a and v in comp_b, within ``max_dist``.

        Returns None if no pair is that close. On ties, the first pair found
        wins.
        """
        best_pair, best_dist = None, None
        for u in comp_a:
            for v in comp_b:
                if u == v:
                    continue
                d = self._chebyshev(u, v)
                if d <= max_dist and (best_dist is None or d < best_dist):
                    best_pair, best_dist = (u, v), d
        return best_pair

    def _choose_bridge(self, comp_a, comp_b, best_pair, max_dist):
        """Pick the edge used to join two components.

        Prefer ``best_pair``. If it would cross an existing edge, take the
        first non-crossing pair within ``max_dist``. If none exists, fall back
        to ``best_pair``: connectivity wins, and a later cleanup pass deals
        with the crossing.
        """
        if not self._would_create_intersection(*best_pair):
            return best_pair
        for u in comp_a:
            for v in comp_b:
                if u == v or self._chebyshev(u, v) > max_dist:
                    continue
                if not self.graph.has_edge(u, v) and not self._would_create_intersection(u, v):
                    return u, v
        return best_pair

    def _attempt_reconnect_components(self, prefer_max_distance=2):
        """Try to join disconnected components with short edges.

        Components are merged pairwise through their closest node pair within
        the current Chebyshev limit. The limit starts at
        ``prefer_max_distance`` and grows by one, up to ``grid_size``, whenever
        a round makes no connection. Non-crossing edges are preferred (see
        _choose_bridge).
        """
        comp_nodes = [list(c) for c in nx.connected_components(self.graph)]
        if len(comp_nodes) <= 1:
            return

        relax = prefer_max_distance
        while relax <= self.grid_size and len(comp_nodes) > 1:
            made_connection = False
            i = 0
            while i < len(comp_nodes) - 1:
                for j in range(i + 1, len(comp_nodes)):
                    best_pair = self._closest_pair(comp_nodes[i], comp_nodes[j], relax)
                    if best_pair is None:
                        continue
                    u, v = self._choose_bridge(comp_nodes[i], comp_nodes[j], best_pair, relax)
                    self.graph.add_edge(u, v)
                    comp_nodes[i].extend(comp_nodes.pop(j))
                    made_connection = True
                    break  # keep i: try to merge more into the grown component
                else:
                    i += 1  # nothing within reach of component i

            if made_connection:
                comp_nodes = [list(c) for c in nx.connected_components(self.graph)]
            else:
                relax += 1  # relax the distance constraint and try again

    def plot(self):
        """Draw the graph with matplotlib (first coordinate downward, second to the right)."""
        plt.figure(figsize=(8, 8))
        pos = {node: (node[1], -node[0]) for node in self.graph.nodes()}
        nx.draw(self.graph, pos, with_labels=True, node_size=300, font_size=8)
        plt.title(f"Generated Network ({self.size}, {self.variant})")
        plt.grid(True)
        plt.show()