# VCS residue (presumably last-commit author, message, and hash):
# asammoud
# Re-add large CSVs using Git LFS
# b265364
import numpy as np
import matplotlib.pyplot as plt
import streamlit as st
from PIL import Image
import cv2
import networkx as nx # <-- Added
# Detection class labels that become graph nodes.
_GRAPH_SYMBOL_TYPES = frozenset({"connector", "crossing", "border_node"})


def _dist(p1, p2):
    """Return the Euclidean distance between two (x, y) points."""
    return float(np.hypot(p1[0] - p2[0], p1[1] - p2[1]))


def _undirected_angle(p1, p2):
    """Return the orientation of segment p1-p2 in degrees, folded into [0, 180)."""
    return float(np.degrees(np.arctan2(p2[1] - p1[1], p2[0] - p1[0])) % 180)


def _angle_gap(a1, a2):
    """Return the smallest difference between two orientations in [0, 180).

    Orientations wrap at 180 degrees, so 179 and 1 are 2 degrees apart, not 178.
    """
    gap = abs(a1 - a2) % 180
    return min(gap, 180 - gap)


def _lines_are_similar(line1, line2, max_distance=10, max_angle_diff=10):
    """Return True if two segments are nearly parallel and their midpoints are close."""
    (a1, a2), (b1, b2) = line1, line2
    if _angle_gap(_undirected_angle(a1, a2), _undirected_angle(b1, b2)) > max_angle_diff:
        return False
    mid_a = ((a1[0] + a2[0]) / 2, (a1[1] + a2[1]) / 2)
    mid_b = ((b1[0] + b2[0]) / 2, (b1[1] + b2[1]) / 2)
    return _dist(mid_a, mid_b) < max_distance


def _span_of_group(reference, group):
    """Collapse a group of near-parallel segments into a single segment.

    Every endpoint is projected onto the direction of ``reference``; the
    endpoints with the smallest and largest projection become the merged
    segment. Unlike taking bounding-box corners, this keeps the orientation
    of descending (anti-diagonal) lines.
    """
    (x1, y1), (x2, y2) = reference
    dx, dy = float(x2 - x1), float(y2 - y1)
    length = float(np.hypot(dx, dy))
    if length == 0.0:
        dx, dy = 1.0, 0.0  # degenerate reference: fall back to the x axis
    else:
        dx, dy = dx / length, dy / length
    endpoints = [pt for segment in group for pt in segment]
    start = min(endpoints, key=lambda p: p[0] * dx + p[1] * dy)
    end = max(endpoints, key=lambda p: p[0] * dx + p[1] * dy)
    return (int(start[0]), int(start[1])), (int(end[0]), int(end[1]))


def _merge_similar_lines(lines):
    """Greedily merge each unused segment with every later segment similar to it."""
    merged, used = [], set()
    for i, reference in enumerate(lines):
        if i in used:
            continue
        used.add(i)
        group = [reference]
        # Every earlier index is already in `used`, so only later ones need checking.
        for j in range(i + 1, len(lines)):
            if j not in used and _lines_are_similar(reference, lines[j]):
                group.append(lines[j])
                used.add(j)
        merged.append(_span_of_group(reference, group))
    return merged


def _point_inside_bbox(px, py, bbox):
    """Return True if (px, py) lies inside the inclusive box (x1, y1, x2, y2)."""
    x1, y1, x2, y2 = bbox
    return x1 <= px <= x2 and y1 <= py <= y2


def _find_nearest_symbol(point, symbols, max_dist=15):
    """Return the symbol whose centre is closest to *point* within *max_dist*.

    If no centre is that close, fall back to the first symbol whose bounding
    box contains the point; return None when neither matches.
    """
    nearest, best = None, float("inf")
    for sym in symbols:
        d = _dist(point, sym["pos"])
        if d < best and d <= max_dist:
            nearest, best = sym, d
    if nearest is not None:
        return nearest
    px, py = point
    return next((s for s in symbols if _point_inside_bbox(px, py, s["bbox"])), None)


def _extract_symbols(detections, class_names):
    """Turn the detections whose label is in _GRAPH_SYMBOL_TYPES into symbol dicts.

    Each symbol gets an id "{label}_{i}", where i is its index in the returned list.
    """
    symbols = []
    for box, class_id in zip(detections.xyxy, detections.class_id):
        label = class_names[class_id]
        if label not in _GRAPH_SYMBOL_TYPES:
            continue
        x1, y1, x2, y2 = map(int, box)
        symbols.append({
            "id": f"{label}_{len(symbols)}",
            "type": label,
            "pos": ((x1 + x2) // 2, (y1 + y2) // 2),
            "bbox": (x1, y1, x2, y2),
        })
    return symbols


def _detect_line_segments(image_bgr):
    """Find straight segments with Canny edges + probabilistic Hough transform."""
    gray = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
    blurred = cv2.GaussianBlur(gray, (3, 3), 0)
    edges = cv2.Canny(blurred, 50, 150, apertureSize=3)
    lines = cv2.HoughLinesP(edges, 1, np.pi / 180, threshold=50, minLineLength=50, maxLineGap=10)
    if lines is None:
        return []
    # HoughLinesP yields rows of (x1, y1, x2, y2); flatten any extra nesting.
    return [((x1, y1), (x2, y2)) for x1, y1, x2, y2 in lines.reshape(-1, 4).tolist()]


def _connect_segments(segments, symbols):
    """Return (segment, id_a, id_b) for each segment whose ends snap to two distinct symbols."""
    connections = []
    for start, end in segments:
        sym_a = _find_nearest_symbol(start, symbols)
        sym_b = _find_nearest_symbol(end, symbols)
        if sym_a is not None and sym_b is not None and sym_a is not sym_b:
            connections.append(((start, end), sym_a["id"], sym_b["id"]))
    return connections


def _draw_overlay(image_bgr, symbols, segments):
    """Return a copy of *image_bgr* with symbol boxes/centres/labels and segments drawn."""
    output = image_bgr.copy()
    for sym in symbols:
        x1, y1, x2, y2 = sym["bbox"]
        cv2.rectangle(output, (x1, y1), (x2, y2), (255, 0, 0), 2)
        cv2.circle(output, sym["pos"], 3, (0, 255, 255), -1)
        cv2.putText(output, sym["type"], (x1, y1 - 5), cv2.FONT_HERSHEY_SIMPLEX, .6, (255, 0, 0), 1)
    for start, end in segments:
        cv2.line(output, start, end, (0, 100, 255), 2)
    return output


def _plot_graph(graph):
    """Render *graph* with matplotlib inside Streamlit, laid out at symbol positions."""
    fig, ax = plt.subplots(figsize=(8, 8))
    # Image y grows downward but matplotlib's grows upward; negate y so the
    # drawing is not upside down relative to the annotated image.
    layout = {node: (x, -y) for node, (x, y) in graph.nodes(data="pos")}
    labels = dict(graph.nodes(data="label"))
    nx.draw(graph, layout, labels=labels, node_size=700, node_color='lightblue',
            font_size=8, with_labels=True, ax=ax)
    ax.set_title("Extracted Graph from Detected Symbols and Lines")
    st.pyplot(fig)
    # Release the figure; pyplot otherwise keeps every figure created across reruns.
    plt.close(fig)


def build_graph(pil_image, detections, annotations, class_names):
    """Detect symbol-to-symbol connections in a diagram and return them as a graph.

    Steps:
      1. Keep the detections whose label is in ``_GRAPH_SYMBOL_TYPES``.
      2. Find straight segments (Canny + probabilistic Hough) and merge
         near-duplicates.
      3. Keep the segments whose two ends snap to two different symbols.
      4. Show an annotated overlay and the resulting NetworkX graph in Streamlit.

    Args:
        pil_image: Source diagram. A PIL image is converted to RGB first, so
            grayscale/RGBA inputs work; anything else is passed to np.asarray
            and is assumed to already be RGB (TODO confirm with callers).
        detections: Object exposing parallel ``xyxy`` (x1, y1, x2, y2 boxes)
            and ``class_id`` iterables. Presumably a ``supervision.Detections``
            object; verify against the caller.
        annotations: Unused; kept for interface compatibility.
        class_names: Indexed by ``class_id`` to obtain each detection's label.

    Returns:
        networkx.Graph with one node per kept symbol. Node ids are
        "{type}_{i}"; each node has attributes ``label`` and ``pos`` (image
        pixel coordinates). There is one edge per connecting segment.
    """
    if isinstance(pil_image, Image.Image):
        rgb = np.array(pil_image.convert("RGB"))
    else:
        rgb = np.asarray(pil_image)
    image_bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    symbols = _extract_symbols(detections, class_names)
    segments = _merge_similar_lines(_detect_line_segments(image_bgr))
    # Snap each segment to its end symbols once; reused by both the overlay and the graph.
    connections = _connect_segments(segments, symbols)

    overlay = _draw_overlay(image_bgr, symbols, [seg for seg, _, _ in connections])
    # NOTE(review): recent Streamlit releases deprecate use_column_width in favour of
    # use_container_width; confirm the pinned version before switching.
    st.image(Image.fromarray(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB)),
             caption="Graph: Merged Lines + Detected Symbols",
             use_column_width=True)

    graph = nx.Graph()
    for sym in symbols:
        graph.add_node(sym["id"], label=sym["type"], pos=sym["pos"])
    graph.add_edges_from((id_a, id_b) for _, id_a, id_b in connections)

    _plot_graph(graph)
    return graph