"""Terminal-block inspection app.

Pipeline: Roboflow object detection -> grouping of detections into vertical
columns -> GPT-4o OCR of each column crop -> clean-up and comparison against
a reference wiring table (``terminal.json``) -> annotated image and results
table shown in a Gradio UI.
"""

import base64
import json
import logging
import os
import re

import cv2
import gradio as gr
import numpy as np
import pandas as pd
from openai import OpenAI
from roboflow import Roboflow

logger = logging.getLogger(__name__)

# ================= CONFIG =================
# NOTE(security): an API key is committed in this file. The literal is kept
# only as a fallback so existing deployments keep working; prefer setting the
# ROBOFLOW_API_KEY environment variable, then rotate and remove the literal.
ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY") or "uP19IAi98TqwLvHmNB8V"
ROBOFLOW_PROJECT = "terminal-block-jtgsl"
ROBOFLOW_VERSION = 1
CONF_THRESHOLD = 0.30
# NOTE(review): IOU_THRESHOLD is defined but never passed to model.predict();
# left unused here so detection behaviour does not change.
IOU_THRESHOLD = 0.4
TERMINAL_JSON_PATH = "terminal.json"

client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
rf = Roboflow(api_key=ROBOFLOW_API_KEY)
model = rf.workspace().project(ROBOFLOW_PROJECT).version(ROBOFLOW_VERSION).model

# Values that mean "no wire was read" (compared after strip().upper()).
_BLANK_WIRE_TOKENS = frozenset({"", "NONE", "N/A", "EMPTY"})

# Wire label format: 1-3 letters, 2-4 digits, 0-2 letters (e.g. "IL23CA").
_WIRE_PATTERN = re.compile(r"[A-Z]{1,3}[0-9]{2,4}[A-Z]{0,2}")
# Same shape, but tolerating O/I misread inside the numeric section.
_WIRE_OCR_PATTERN = re.compile(r"([A-Z]{1,3})([0-9OI]{2,4})([A-Z]{0,2})")
_OCR_DIGIT_FIXES = str.maketrans({"O": "0", "I": "1"})

_RESULT_COLOR_OK = (0, 255, 0)   # green (BGR)
_RESULT_COLOR_BAD = (0, 0, 255)  # red (BGR)

OCR_PROMPT = """
STRICT RULES:
- One ID = one vertical column
- Terminal = number below screws
- Wire = text on white sleeve (ILxxx)
- NEVER merge columns
- NEVER skip digits
- If unclear return NONE

Output STRICT JSON:
[{"id":0,"terminal":"77","wire":"IL23CA"}]
"""


# ================= LOAD REFERENCE =================
def load_terminal_reference():
    """Load the expected terminal -> wire mapping from TERMINAL_JSON_PATH.

    The file is expected to hold ``{"terminal_blocks": [{"terminal": ...,
    "wire": ...}, ...]}``. Entries with a falsy ``wire`` are skipped; keys and
    values are stripped and upper-cased.

    Returns:
        dict mapping terminal id to wire label, or ``{}`` when the file is
        missing, unreadable, or malformed (best effort: the app still starts,
        and every terminal is then reported as "UNKNOWN").
    """
    if not os.path.exists(TERMINAL_JSON_PATH):
        return {}
    try:
        with open(TERMINAL_JSON_PATH, "r", encoding="utf-8") as f:
            data = json.load(f)
        return {
            str(entry["terminal"]).strip().upper(): str(entry["wire"]).strip().upper()
            for entry in data.get("terminal_blocks", [])
            if entry.get("wire")
        }
    except (OSError, ValueError, KeyError, TypeError, AttributeError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        logger.warning("Could not load %s: %s", TERMINAL_JSON_PATH, exc)
        return {}


terminal_reference = load_terminal_reference()


def clean_terminal(text):
    """Return only the ASCII digits of *text* (terminal ids are numeric)."""
    return re.sub(r"[^0-9]", "", text)


def clean_wire(text):
    """Normalise an OCR'd wire label.

    Upper-cases and drops everything except A-Z / 0-9. The common OCR
    confusions O->0 and I->1 are corrected only inside the numeric section of
    a label; replacing them everywhere would corrupt letter prefixes such as
    "IL" (turning "IL23CA" into the invalid "1L23CA").
    """
    text = re.sub(r"[^A-Z0-9]", "", text.upper())
    if _WIRE_PATTERN.fullmatch(text):
        return text  # already well-formed; do not touch the letters
    match = _WIRE_OCR_PATTERN.fullmatch(text)
    if match:
        prefix, digits, suffix = match.groups()
        return prefix + digits.translate(_OCR_DIGIT_FIXES) + suffix
    return text


def is_valid_wire(wire):
    """Return True when *wire* has the form letters(1-3) digits(2-4) letters(0-2)."""
    return bool(_WIRE_PATTERN.fullmatch(wire))


def validate_and_fix(t, w):
    """Clean an OCR'd (terminal, wire) pair, falling back to the reference.

    Args:
        t: raw terminal text.
        w: raw wire text.

    Returns:
        ``(terminal, wire)`` as cleaned strings, or ``(None, None)`` when the
        terminal contains no digits. A blank or malformed wire is replaced by
        the reference wire for that terminal when one exists; a blank wire
        without a reference becomes "NONE".
    """
    raw_wire = w
    t = clean_terminal(t)
    if not t:
        return None, None

    # Check for "no reading" on the raw text: cleaning would strip the slash
    # from "N/A" and hide the placeholder.
    if raw_wire.strip().upper() in _BLANK_WIRE_TOKENS:
        return t, terminal_reference.get(t, "NONE")

    w = clean_wire(raw_wire)
    if not w:
        return t, terminal_reference.get(t, "NONE")
    if not is_valid_wire(w) and t in terminal_reference:
        w = terminal_reference[t]
    return t, w


# ================= IMPROVED PREPROCESSING =================
def prepare_for_roboflow(img, max_side=1600):
    """Downscale *img* so its longer side is at most *max_side* pixels.

    Images that are already small enough are returned unchanged (same object).
    """
    h, w = img.shape[:2]
    scale = min(max_side / max(h, w), 1)
    if scale >= 1:
        return img
    return cv2.resize(img, (max(1, int(w * scale)), max(1, int(h * scale))))


def upscale(img, target_height=800):
    """Enlarge *img* to *target_height* pixels high, keeping the aspect ratio.

    Lanczos interpolation is used so thin strokes stay separated (e.g. "11"
    does not blur into "1"). Empty images and images already at least
    *target_height* high are returned unchanged.
    """
    if img.size == 0:
        return img
    h = img.shape[0]
    if h >= target_height:
        return img
    scale = target_height / h
    return cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_LANCZOS4)


def enhance_variants(img):
    """Return the image variants sent to the OCR model for one crop.

    Two variants are produced: the crop as-is, and a CLAHE contrast-enhanced,
    denoised and sharpened grayscale version converted back to 3 channels.
    Returns an empty list for an empty image.
    """
    if img.size == 0:
        return []

    gray = img if img.ndim == 2 else cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    clahe = cv2.createCLAHE(clipLimit=4.0, tileGridSize=(12, 12))
    enhanced_gray = clahe.apply(gray)

    # Denoise before sharpening so the sharpening kernel does not amplify
    # noise around thin characters.
    denoised = cv2.fastNlMeansDenoising(enhanced_gray, None, 10, 7, 21)
    kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    sharpened = cv2.filter2D(denoised, -1, kernel)

    return [img, cv2.cvtColor(sharpened, cv2.COLOR_GRAY2BGR)]


def img_to_base64(img):
    """Encode *img* as a quality-95 JPEG and return it as a base64 string."""
    _, buffer = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
    return base64.b64encode(buffer).decode()


# ================= PIPELINE LOGIC =================
def verify(terminal, wire):
    """Compare a (terminal, wire) reading with the reference table.

    Returns:
        "UNKNOWN" when the terminal is not in the reference; "MATCH";
        "MISSING (Exp <ref>)" when no wire was read but one is expected; or
        "MISMATCH (Exp <ref>)".
    """
    t, w = terminal.strip().upper(), wire.strip().upper()
    if t not in terminal_reference:
        return "UNKNOWN"
    ref = terminal_reference[t]
    if w in _BLANK_WIRE_TOKENS:
        return "MATCH" if ref == "NONE" else f"MISSING (Exp {ref})"
    return "MATCH" if ref == w else f"MISMATCH (Exp {ref})"


def fix_missing_wire(terminal, wire):
    """Return the reference wire when OCR produced no wire, else *wire*.

    Both arguments are stripped and upper-cased; ``None`` is treated as "".
    """
    terminal = (terminal or "").strip().upper()
    wire = (wire or "").strip().upper()
    if wire in _BLANK_WIRE_TOKENS:
        return terminal_reference.get(terminal, wire)
    return wire


def group_by_columns(detections, threshold=30):
    """Group detections into vertical columns by centre x-coordinate.

    Detections are processed left to right; each joins the first column whose
    first (leftmost) member is within *threshold* pixels horizontally,
    otherwise it starts a new column.

    Returns:
        list of columns, each a list of detection dicts.
    """
    columns = []
    for det in sorted(detections, key=lambda d: d["center"][0]):
        for col in columns:
            if abs(col[0]["center"][0] - det["center"][0]) < threshold:
                col.append(det)
                break
        else:
            columns.append([det])
    return columns


def _to_bgr(image):
    """Convert a Gradio image (RGB, RGBA or grayscale) to OpenCV's BGR order."""
    if image.ndim == 2:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)


def _detect_label_boxes(img):
    """Run Roboflow detection and return the boxes used for column grouping.

    Returns terminal-number, wire-number and terminal detections (in that
    order), each as ``{"class", "bbox": (x1, y1, x2, y2), "center": (x, y)}``
    with the bbox clipped to the image.
    """
    height, width = img.shape[:2]
    preds = model.predict(img, confidence=round(CONF_THRESHOLD * 100)).json()["predictions"]

    by_class = {"Terminal Number": [], "Wire Number": [], "Terminal": []}
    for p in preds:
        if p["class"] not in by_class:
            continue  # e.g. "Wire" boxes are not used for OCR regions
        x, y, w, h = map(int, [p["x"], p["y"], p["width"], p["height"]])
        by_class[p["class"]].append({
            "class": p["class"],
            "bbox": (
                max(0, x - w // 2),
                max(0, y - h // 2),
                min(width, x + w // 2),
                min(height, y + h // 2),
            ),
            "center": (x, y),
        })
    return by_class["Terminal Number"] + by_class["Wire Number"] + by_class["Terminal"]


def _build_ocr_regions(detections, width, height, pad=10):
    """Return one padded union bounding box per detection column."""
    regions = []
    for i, col in enumerate(group_by_columns(detections, threshold=30)):
        x1 = min(d["bbox"][0] for d in col)
        y1 = min(d["bbox"][1] for d in col)
        x2 = max(d["bbox"][2] for d in col)
        y2 = max(d["bbox"][3] for d in col)
        regions.append({
            "union_bbox": (
                max(0, x1 - pad),
                max(0, y1 - pad),
                min(width, x2 + pad),
                min(height, y2 + pad),
            ),
            "id": i,
        })
    return regions


def _build_ocr_content(img, ocr_regions):
    """Build the chat message content: the prompt, then id + images per region."""
    content = [{"type": "text", "text": OCR_PROMPT}]
    for region in ocr_regions:
        x1, y1, x2, y2 = region["union_bbox"]
        variants = enhance_variants(upscale(img[y1:y2, x1:x2]))
        if not variants:
            continue  # degenerate (empty) crop: nothing to read
        content.append({"type": "text", "text": f"id:{region['id']}"})
        for v in variants:
            content.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{img_to_base64(v)}"},
            })
    return content


def _region_index(value, count):
    """Return *value* as a valid region index in ``range(count)``, else None."""
    if isinstance(value, str) and value.strip().isdigit():
        value = int(value.strip())
    if isinstance(value, bool) or not isinstance(value, int):
        return None
    return value if 0 <= value < count else None


def _parse_ocr_response(res_text, ocr_regions):
    """Turn the model's JSON answer into validated result rows.

    Items that are not objects, reference an unknown region id, or have no
    usable terminal number are skipped individually so that one bad item does
    not discard the rest.
    """
    match = re.search(r"\[.*\]", res_text or "", re.DOTALL)
    if not match:
        return []
    parsed = json.loads(match.group())
    if not isinstance(parsed, list):
        return []

    rows = []
    for item in parsed:
        if not isinstance(item, dict):
            continue
        idx = _region_index(item.get("id"), len(ocr_regions))
        if idx is None:
            continue
        t = str(item.get("terminal", "")).strip()
        w = str(item.get("wire", "")).strip()
        t, w = validate_and_fix(t, w)
        if t is None:
            continue  # no digits in the terminal reading
        w = fix_missing_wire(t, w)
        rows.append({
            "Terminal": t,
            "Wire": w,
            "Verification": verify(t, w),
            "bbox": ocr_regions[idx]["union_bbox"],
        })
    return rows


def _draw_results(img, results):
    """Return a copy of *img* (BGR) with one labelled box per result."""
    vis = img.copy()
    for r in results:
        x1, y1, x2, y2 = r["bbox"]
        # Exact comparison: "MISMATCH (...)" also contains the text "MATCH".
        color = _RESULT_COLOR_OK if r["Verification"] == "MATCH" else _RESULT_COLOR_BAD
        cv2.rectangle(vis, (x1, y1), (x2, y2), color, 2)
        cv2.putText(
            vis,
            f"T:{r['Terminal']}",
            (x1, max(y1 - 10, 15)),  # keep the label inside the image
            cv2.FONT_HERSHEY_SIMPLEX,
            0.6,
            color,
            2,
        )
    return vis


def run_pipeline(image):
    """Inspect one rail image and return the annotated image and result table.

    Args:
        image: image array as delivered by the Gradio ``Image`` component
            (RGB channel order), or ``None``.

    Returns:
        ``(annotated_image, dataframe)``. The image is RGB for display in
        Gradio; the dataframe has columns Terminal / Wire / Verification,
        sorted by terminal number. ``(None, empty dataframe)`` when *image* is
        ``None``.

    Errors from the OCR request or from parsing its answer are logged and
    result in an empty table; errors from the detection request propagate.
    """
    if image is None:
        return None, pd.DataFrame()

    # OpenCV encoders/colour conversions assume BGR; Gradio supplies RGB.
    img = prepare_for_roboflow(_to_bgr(image))
    height, width = img.shape[:2]

    # ================= DETECTION =================
    detections = _detect_label_boxes(img)

    # ================= COLUMN GROUPING =================
    ocr_regions = _build_ocr_regions(detections, width, height)
    if not ocr_regions:
        # Nothing to read: skip the OCR request entirely.
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB), pd.DataFrame()

    # ================= GPT OCR =================
    content = _build_ocr_content(img, ocr_regions)
    results = []
    try:
        response = client.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": content}],
            temperature=0,
        )
        results = _parse_ocr_response(response.choices[0].message.content, ocr_regions)
    except Exception:
        # Boundary handler: keep the UI responsive and show an empty table.
        logger.exception("OCR request or response parsing failed")

    # ================= SORT =================
    def safe_int(x):
        digits = "".join(filter(str.isdigit, x))
        return int(digits) if digits else 999

    results = sorted(results, key=lambda r: safe_int(r["Terminal"]))

    # ================= VISUAL =================
    vis = _draw_results(img, results)
    table = pd.DataFrame(results).drop(columns=["bbox"], errors="ignore")
    return cv2.cvtColor(vis, cv2.COLOR_BGR2RGB), table


# ================= UI =================
with gr.Blocks(title="Terminal Assembly Inspector") as demo:
    gr.Markdown("## Terminal Detector ")
    with gr.Row():
        img_in = gr.Image(type="numpy", label="Input Rail")
        img_out = gr.Image(label="Detections (Red = Error)")
    btn = gr.Button("Analyze Entire Rail", variant="primary")
    table = gr.Dataframe(headers=["Terminal", "Wire", "Verification"])
    btn.click(run_pipeline, [img_in], [img_out, table])

demo.launch()