| import os |
| import cv2 |
| import base64 |
| import json |
| import pandas as pd |
| import gradio as gr |
| import numpy as np |
| from roboflow import Roboflow |
| from openai import OpenAI |
| from concurrent.futures import ThreadPoolExecutor |
| import re |
|
|
| |
# NOTE(security): a Roboflow API key is committed in source. It is kept only as a
# fallback so existing deployments keep working; set ROBOFLOW_API_KEY in the
# environment instead, then rotate this key and remove it from the code.
ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY") or "uP19IAi98TqwLvHmNB8V"
ROBOFLOW_PROJECT = "terminal-block-jtgsl"
ROBOFLOW_VERSION = 1
# Detector confidence; run_pipeline passes it to model.predict as a percentage.
CONF_THRESHOLD = 0.30
# NOTE(review): not referenced anywhere in this file (model.predict is not given
# an overlap value) - confirm whether it was meant to be passed through.
IOU_THRESHOLD = 0.4
# Reference table of expected terminal -> wire assignments.
TERMINAL_JSON_PATH = "terminal.json"
|
|
# Shared remote clients, created once when the module is imported.
# The OpenAI key is taken from the OPENAI_API_KEY environment variable.
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
rf = Roboflow(api_key=ROBOFLOW_API_KEY)
_rf_project = rf.workspace().project(ROBOFLOW_PROJECT)
model = _rf_project.version(ROBOFLOW_VERSION).model
del _rf_project  # only needed to reach the versioned model
|
|
| |
def load_terminal_reference(path=None):
    """Load the expected terminal -> wire mapping from the reference JSON file.

    The file is read as ``{"terminal_blocks": [{"terminal": ..., "wire": ...}, ...]}``.
    Terminal keys and wire values are stringified, stripped and upper-cased.
    Entries that are not objects, have no ``terminal`` key, or have an
    empty/missing ``wire`` are skipped individually (one bad entry no
    longer discards the whole table).

    Args:
        path: file to read; defaults to ``TERMINAL_JSON_PATH``.

    Returns:
        dict mapping terminal string to wire string; ``{}`` when the file is
        absent, unreadable, or not shaped as described above.
    """
    if path is None:
        path = TERMINAL_JSON_PATH
    if not os.path.exists(path):
        return {}

    try:
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # Best effort: the app keeps running with an empty reference table,
        # but the problem is reported instead of silently swallowed.
        print(f"Could not read terminal reference {path}: {e}")
        return {}

    entries = data.get("terminal_blocks", []) if isinstance(data, dict) else []
    if not isinstance(entries, list):
        return {}

    return {
        str(entry["terminal"]).strip().upper(): str(entry["wire"]).strip().upper()
        for entry in entries
        if isinstance(entry, dict) and "terminal" in entry and entry.get("wire")
    }
|
|
# Expected terminal -> wire mapping (upper-cased strings), loaded once at import
# time from the reference JSON; {} when the file is absent or cannot be loaded.
# Read by validate_and_fix, verify and fix_missing_wire.
terminal_reference = load_terminal_reference()
|
|
def clean_terminal(text):
    """Return only the ASCII digits 0-9 found in *text*, in order.

    For example ``"T-77a"`` becomes ``"77"``; text without digits becomes ``""``.
    """
    return "".join(ch for ch in text if "0" <= ch <= "9")
|
|
def clean_wire(text):
    """Normalise an OCR'd wire label such as ``"IL23CA"``.

    The text is upper-cased and reduced to the characters A-Z and 0-9. If the
    result already has the wire shape (1-3 letters, 2-4 digits, up to 2
    trailing letters) it is returned unchanged. Otherwise the common OCR
    confusions O->0 and I->1 are corrected, but only inside the digit run
    that follows the leading letters. Letter prefixes such as "IL" and
    trailing letter suffixes are never rewritten, and neither is the
    all-letter token "NONE".

    Args:
        text: raw wire text.

    Returns:
        The cleaned label; ``""`` when nothing alphanumeric remains.
    """
    text = re.sub(r'[^A-Z0-9]', '', text.upper())

    # Same shape that is_valid_wire() accepts.
    if re.match(r'^[A-Z]{1,3}[0-9]{2,4}[A-Z]{0,2}$', text):
        return text

    prefix = re.match(r'[A-Z]*', text).group()
    body = text[len(prefix):]
    # The number may contain misread O/I characters; whatever follows it is the suffix.
    number = re.match(r'[0-9OI]*', body).group()
    suffix = body[len(number):]
    return prefix + number.translate(str.maketrans("OI", "01")) + suffix
|
|
def is_valid_wire(wire):
    """Tell whether *wire* has the label shape: 1-3 letters, 2-4 digits, 0-2 letters.

    Matching is case-sensitive (upper-case letters only), e.g. "IL23CA" or "A101".
    """
    found = re.match(r'^[A-Z]{1,3}[0-9]{2,4}[A-Z]{0,2}$', wire)
    return found is not None
|
|
def validate_and_fix(t, w, *, fill_from_reference=False, reference=None):
    """Clean one OCR'd (terminal, wire) pair before verification.

    The terminal is reduced to its digits. If none remain, the pair is
    unusable and ``(None, None)`` is returned. The wire is first compared
    against the "no wire" tokens on the raw text, before clean_wire()
    rewrites any characters, and is normalised to ``"NONE"``. Any other
    wire goes through clean_wire().

    By default the OCR reading is NOT replaced by the expected wire from
    the reference table. verify() compares against that same table, so
    substituting it would report MATCH for missing or misread wires.

    Args:
        t: raw terminal text from the model.
        w: raw wire text from the model.
        fill_from_reference: opt in to auto-correction. A missing or
            malformed wire is replaced by the reference value for ``t``.
            Do not use this when the result is then passed to verify().
        reference: terminal -> wire mapping used when filling; defaults to
            the module-level ``terminal_reference``.

    Returns:
        ``(terminal, wire)`` as strings, or ``(None, None)``.
    """
    t = clean_terminal(t)
    if not t:
        return None, None

    raw_wire = str(w).strip().upper()
    if raw_wire in ("", "NONE", "N/A", "EMPTY"):
        w = "NONE"
    else:
        # A reading made only of punctuation cleans down to nothing.
        w = clean_wire(raw_wire) or "NONE"

    if fill_from_reference:
        ref_table = terminal_reference if reference is None else reference
        if t in ref_table and (w == "NONE" or not is_valid_wire(w)):
            w = ref_table[t]

    return t, w
| |
def prepare_for_roboflow(img, max_side=1600):
    """Shrink *img* so that its longer side is at most *max_side* pixels.

    Images that already fit are returned as the same object (no copy);
    larger ones are resized proportionally with cv2.resize.
    """
    height, width = img.shape[:2]
    factor = min(max_side / max(height, width), 1)
    if factor >= 1:
        return img
    target_size = (int(width * factor), int(height * factor))
    return cv2.resize(img, target_size)
|
|
def upscale(img):
    """Enlarge *img* with cubic interpolation so it is 1000 px tall.

    Empty images are returned untouched. Images already at least 1000 px
    tall still pass through cv2.resize with a factor of 1.0.
    """
    if img.size == 0:
        return img
    height = img.shape[0]
    factor = 1.0 if height >= 1000 else 1000 / height
    return cv2.resize(img, None, fx=factor, fy=factor, interpolation=cv2.INTER_CUBIC)
|
|
def enhance_variants(img):
    """Return the image versions sent to the OCR model.

    The list holds *img* itself, followed by a grayscale copy (converted back
    to 3 channels) that has been CLAHE-equalised, denoised with non-local
    means and sharpened with a 3x3 kernel. An empty image yields ``[]``.
    """
    if img.size == 0:
        return []

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    equaliser = cv2.createCLAHE(clipLimit=4.0, tileGridSize=(12, 12))
    smooth = cv2.fastNlMeansDenoising(equaliser.apply(gray), None, 10, 7, 21)
    sharpen_kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    crisp = cv2.filter2D(smooth, -1, sharpen_kernel)
    return [img, cv2.cvtColor(crisp, cv2.COLOR_GRAY2BGR)]
|
|
def img_to_base64(img):
    """JPEG-encode *img* at quality 85 and return the bytes as a base64 ``str``."""
    jpeg_params = [int(cv2.IMWRITE_JPEG_QUALITY), 85]
    _, jpeg_bytes = cv2.imencode(".jpg", img, jpeg_params)
    encoded = base64.b64encode(jpeg_bytes)
    return encoded.decode()
|
|
| from concurrent.futures import ThreadPoolExecutor |
|
|
| |
def verify(terminal, wire, *, reference=None):
    """Compare an OCR'd wire with the expected wire for *terminal*.

    Both inputs are stripped and upper-cased; ``None`` is treated as empty.

    Args:
        terminal: terminal identifier.
        wire: observed wire label; "NONE", "EMPTY", "N/A" or "" mean no wire.
        reference: terminal -> wire mapping; defaults to the module-level
            ``terminal_reference``.

    Returns:
        "UNKNOWN" when the terminal is not in the reference. Otherwise
        "MATCH" or "MISSING (Exp <ref>)" for an absent wire, and "MATCH" or
        "MISMATCH (Exp <ref>)" for a present one.
    """
    ref_table = terminal_reference if reference is None else reference
    t = "" if terminal is None else str(terminal).strip().upper()
    w = "" if wire is None else str(wire).strip().upper()

    if t not in ref_table:
        return "UNKNOWN"
    ref = ref_table[t]
    if w in ["NONE", "EMPTY", "N/A", ""]:
        return "MATCH" if ref == "NONE" else f"MISSING (Exp {ref})"
    return "MATCH" if ref == w else f"MISMATCH (Exp {ref})"
|
|
def fix_missing_wire(terminal, wire):
    """Substitute the reference wire when *wire* is a "no wire" token.

    Both arguments are stripped and upper-cased. If the wire is "NONE", ""
    or "N/A" and the terminal has an entry in ``terminal_reference``, that
    entry is returned; otherwise the normalised wire is returned.
    """
    key = terminal.strip().upper()
    observed = wire.strip().upper()
    is_missing = observed in ("NONE", "", "N/A")
    if not is_missing or key not in terminal_reference:
        return observed
    return terminal_reference[key]
|
|
def group_by_columns(detections, threshold=30):
    """Cluster detections into vertical columns by the x coordinate of their centre.

    Detections are visited left to right (stable sort on ``center[0]``). Each
    one joins a column whose first member's x is less than *threshold* pixels
    away; otherwise it starts a new column. The input list is not modified.

    Returns:
        List of columns (lists of detections), ordered left to right.
    """
    columns = []
    for det in sorted(detections, key=lambda d: d["center"][0]):
        x = det["center"][0]
        # Column anchors are created in increasing x and are at least
        # `threshold` apart, so only the newest column can still be in range.
        if columns and x - columns[-1][0]["center"][0] < threshold:
            columns[-1].append(det)
        else:
            columns.append([det])
    return columns
|
|
|
|
| |
def process_region_batch(batch_regions, img):
    """Ask GPT-4o to read the terminal and wire labels of a batch of column regions.

    Each region crop that is at least 25x25 px is resized to a fixed working
    height. It is sent as an ``id:<n>`` text marker followed by the variants
    from enhance_variants(), after a fixed instruction prompt.

    Args:
        batch_regions: dicts with "union_bbox" (x1, y1, x2, y2) and "id".
        img: the image the boxes refer to.

    Returns:
        The dicts parsed from the first JSON array found in the model's reply
        (non-dict entries dropped). Returns ``[]`` when no region in the batch
        was large enough to send, when the request failed, or when no array
        could be parsed.
    """
    target_height = 1000  # every crop is sent at this height

    content = [{
        "type": "text",
        "text": """
STRICT RULES:
- One ID = one vertical column
- Terminal = number below screws
- Wire = text on white sleeve (ILxxx)
- NEVER merge columns
- NEVER skip digits
- If unclear return NONE

Output STRICT JSON:
[{"id":0,"terminal":"77","wire":"IL23CA"}]
"""
    }]

    regions_sent = 0
    for region in batch_regions:
        x1, y1, x2, y2 = region["union_bbox"]
        roi = img[y1:y2, x1:x2]

        # Crops under 25 px in either direction are skipped.
        if roi.size == 0 or roi.shape[0] < 25 or roi.shape[1] < 25:
            continue

        # Resize to the working height in one step. Shrinking tall crops first
        # and enlarging them again afterwards would only discard detail.
        scale = target_height / roi.shape[0]
        interpolation = cv2.INTER_CUBIC if scale > 1 else cv2.INTER_AREA
        roi = cv2.resize(roi, None, fx=scale, fy=scale, interpolation=interpolation)

        # The id marker precedes this region's images so the model can echo it back.
        content.append({"type": "text", "text": f"id:{region['id']}"})
        for v in enhance_variants(roi):
            content.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{img_to_base64(v)}"}
            })
        regions_sent += 1

    # Nothing readable in this batch: skip a paid request that carries no images.
    if not regions_sent:
        return []

    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": content}],
            temperature=0
        )

        res_text = response.choices[0].message.content or ""
        match = re.search(r'\[\s*{.*?}\s*\]', res_text, re.DOTALL)

        if match:
            parsed = json.loads(match.group())
            # The caller calls .get() on every item, so keep only JSON objects.
            return [item for item in parsed if isinstance(item, dict)]

    except Exception as e:
        # Best effort per batch: a failed request loses this batch only.
        print("Batch error:", e)

    return []
|
|
|
|
# Detector classes that mark a terminal column. "Wire" boxes are also
# predicted by the model but are not used for grouping.
_COLUMN_CLASSES = ("Terminal Number", "Wire Number", "Terminal")
# Client-side confidence cut applied on top of the CONF_THRESHOLD sent to the detector.
_MIN_CONFIDENCE = 0.35


def _detect_column_parts(img):
    """Run the Roboflow model on *img* and return its column-marking detections.

    Each detection is a dict with "class", "bbox" (x1, y1, x2, y2 built from
    the predicted centre and size, clipped to the image) and "center" (x, y).
    Detections are returned grouped by class in ``_COLUMN_CLASSES`` order.
    """
    H, W = img.shape[:2]
    preds = model.predict(img, confidence=int(CONF_THRESHOLD * 100)).json()["predictions"]

    by_class = {name: [] for name in _COLUMN_CLASSES}
    for p in preds:
        if p["confidence"] <= _MIN_CONFIDENCE or p["class"] not in by_class:
            continue
        x, y, w, h = map(int, [p["x"], p["y"], p["width"], p["height"]])
        by_class[p["class"]].append({
            "class": p["class"],
            "bbox": (
                max(0, x - w // 2),
                max(0, y - h // 2),
                min(W, x + w // 2),
                min(H, y + h // 2),
            ),
            "center": (x, y),
        })
    return [det for name in _COLUMN_CLASSES for det in by_class[name]]


def _column_regions(columns, W, H, pad=10):
    """Return one OCR region per column: the padded union box of its detections, clipped to the image."""
    regions = []
    for i, col in enumerate(columns):
        x1 = min(d["bbox"][0] for d in col)
        y1 = min(d["bbox"][1] for d in col)
        x2 = max(d["bbox"][2] for d in col)
        y2 = max(d["bbox"][3] for d in col)
        regions.append({
            "union_bbox": (
                max(0, x1 - pad),
                max(0, y1 - pad),
                min(W, x2 + pad),
                min(H, y2 + pad),
            ),
            "id": i,  # also the region's index in the returned list
        })
    return regions


def _result_row(item, regions):
    """Turn one model answer into a result row, or return None when it is unusable."""
    if not isinstance(item, dict):
        return None
    # Guard against the id coming back as a string, missing, or out of range.
    try:
        idx = int(item.get("id"))
    except (TypeError, ValueError):
        return None
    if not 0 <= idx < len(regions):
        return None

    t, w = validate_and_fix(
        str(item.get("terminal", "")).strip(),
        str(item.get("wire", "")).strip(),
    )
    if t is None:
        # No digits in the terminal reading (e.g. "NONE"), so there is nothing to verify.
        return None

    # The wire is verified as read. It is deliberately not back-filled from
    # the reference table (fix_missing_wire), because verify() compares
    # against that same table and every missing wire would become a MATCH.
    return {
        "Terminal": t,
        "Wire": w,
        "Verification": verify(t, w),
        "bbox": regions[idx]["union_bbox"],
    }


def _read_regions(regions, img, batch_size=6, max_workers=4):
    """OCR *regions* in batches on a thread pool and return the usable result rows."""
    batches = [regions[i:i + batch_size] for i in range(0, len(regions), batch_size)]
    rows = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [executor.submit(process_region_batch, batch, img) for batch in batches]
        for future in futures:
            for item in future.result():
                row = _result_row(item, regions)
                if row is not None:
                    rows.append(row)
    return rows


def _terminal_sort_key(row):
    """Sort rows by numeric terminal; rows without digits sort as 999."""
    digits = "".join(filter(str.isdigit, row["Terminal"]))
    return int(digits) if digits else 999


def _draw_results(img, results):
    """Return a copy of *img* with each result's region outlined and labelled."""
    vis = img.copy()
    for r in results:
        x1, y1, x2, y2 = r["bbox"]
        # Exact comparison: "MISMATCH (Exp ...)" contains the substring "MATCH".
        # gr.Image(type="numpy") works with RGB arrays, so red is (255, 0, 0) here.
        color = (0, 255, 0) if r["Verification"] == "MATCH" else (255, 0, 0)
        cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)
        # Keep the label inside the image for boxes touching the top edge.
        cv2.putText(
            vis,
            f"T:{r['Terminal']}",
            (x1, max(y1 - 10, 15)),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            color,
            2
        )
    return vis


def run_pipeline(image):
    """Detect terminal columns in *image*, OCR them with GPT-4o and verify the wires.

    Args:
        image: numpy image from the Gradio input, or None when nothing was uploaded.

    Returns:
        A 5-tuple for the Gradio outputs: (annotated image, OCR table,
        verification table, path of the written Excel file, status markdown).
        When *image* is None the image and file path are None and the
        tables are empty.
    """
    if image is None:
        # Must provide all 5 outputs wired to the button, not just 2.
        return (
            None,
            pd.DataFrame(columns=["Terminal", "Wire"]),
            pd.DataFrame(columns=["Terminal", "Wire", "Verification"]),
            None,
            "⚠️ Please upload an image",
        )

    img = prepare_for_roboflow(image)
    H, W = img.shape[:2]

    columns = group_by_columns(_detect_column_parts(img), threshold=30)
    regions = _column_regions(columns, W, H)
    results = sorted(_read_regions(regions, img), key=_terminal_sort_key)
    vis = _draw_results(img, results)

    df = pd.DataFrame(results).drop(columns=["bbox"], errors="ignore")
    if df.empty:
        ocr_df = pd.DataFrame(columns=["Terminal", "Wire"])
        verify_df = pd.DataFrame(columns=["Terminal", "Wire", "Verification"])
    else:
        ocr_df = df[["Terminal", "Wire"]].copy()
        verify_df = df[["Terminal", "Wire", "Verification"]].copy()

    file_path = "terminal_result.xlsx"
    verify_df.to_excel(file_path, index=False)

    status = "✅ Verification Completed" if not df.empty else "⚠️ No detections found"
    return vis, ocr_df, verify_df, file_path, status
|
|
# Custom stylesheet passed to gr.Blocks(css=...) below: dark page background,
# the "Outfit" web font, rounded cards and inputs, and a pink primary button.
# The string is handed to the browser verbatim, so everything inside it is CSS.
apple_dark_pink_css = """
@import url('https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;500;600&display=swap');
.gradio-container {
background: #0f1115;
font-family: 'Outfit', sans-serif;
}
/* Headings */
h1 {
color: #f9fafb;
font-weight: 600;
letter-spacing: 0.5px;
}
h2, h3 {
color: #e5e7eb;
font-weight: 500;
}
/* Cards */
.gr-box {
background: #161a22;
border-radius: 16px;
padding: 12px;
}
/* Primary Button */
button.primary {
background: #f472b6 !important;
color: #020617 !important;
border-radius: 12px;
font-weight: 500;
letter-spacing: 0.5px;
transition: all 0.2s ease;
}
button.primary:hover {
background: #ec4899 !important;
transform: translateY(-1px);
}
/* Inputs */
input, textarea, select {
font-family: 'Outfit', sans-serif !important;
border-radius: 10px !important;
}
"""
|
|
# Gradio UI: image upload -> run_pipeline -> annotated image, two tables,
# downloadable Excel file and a status line.
with gr.Blocks(
    theme=gr.themes.Soft(primary_hue="pink"),
    css=apple_dark_pink_css
) as demo:

    gr.Markdown("# AI-Based Verification of Wire Connections in Terminal Block Cubicles")

    with gr.Row():
        img_in = gr.Image(type="numpy", label="Upload Image")
        img_out = gr.Image(label="Detected Image")

    btn = gr.Button("Run Verification", variant="primary")

    status_msg = gr.Markdown()

    t1 = gr.Dataframe(label="OCR Output")
    t2 = gr.Dataframe(label="Verification Result")
    f = gr.File(label="Download Result")

    # Output components, in order, for run_pipeline's 5-tuple return value.
    btn.click(
        run_pipeline,
        [img_in],
        [img_out, t1, t2, f, status_msg]
    )


# Start the server only when run as a script; importing this module
# (e.g. from tests or other tooling) must not block in demo.launch().
if __name__ == "__main__":
    demo.launch()
|
|
|
|
|
|
|
|
|
|