Update app.py
Browse files
app.py
CHANGED
|
@@ -1,16 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os, cv2, re, base64
|
| 2 |
import numpy as np
|
| 3 |
import pandas as pd
|
| 4 |
import gradio as gr
|
|
|
|
| 5 |
from roboflow import Roboflow
|
| 6 |
from openai import OpenAI
|
| 7 |
from openpyxl import load_workbook
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
# ================= CONFIG =================
|
| 10 |
|
| 11 |
ROBOFLOW_API_KEY = "uP19IAi98TqwLvHmNB8V"
|
| 12 |
ROBOFLOW_PROJECT = "braker3"
|
| 13 |
-
ROBOFLOW_VERSION =
|
| 14 |
CONF_THRESHOLD = 0.35
|
| 15 |
IOU_THRESHOLD = 0.4
|
| 16 |
|
|
@@ -19,9 +29,119 @@ client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
|
| 19 |
rf = Roboflow(api_key=ROBOFLOW_API_KEY)
|
| 20 |
model = rf.workspace().project(ROBOFLOW_PROJECT).version(ROBOFLOW_VERSION).model
|
| 21 |
|
| 22 |
-
CROP_DIR = "
|
| 23 |
os.makedirs(CROP_DIR, exist_ok=True)
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
# ================= CONSTANTS =================
|
| 26 |
|
| 27 |
CIRCUIT_PATTERN = r"(?:\d+L\d+-\d+|S\d+)"
|
|
@@ -30,13 +150,37 @@ DEFAULT_BREAKING_CAPACITY = "85"
|
|
| 30 |
|
| 31 |
VALID_AF_VALUES = {"50","63","100","125","160","250","400","630"}
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
SPEC_JP = {
|
| 34 |
"Manufacture Name": "メーカー",
|
| 35 |
"Circuit Name": "回路番号",
|
| 36 |
"Load Name": "負荷名称",
|
| 37 |
"Breaking Capacity": "遮断容量",
|
| 38 |
"AF": "フレーム(AF)",
|
| 39 |
-
"AT": "トリップ(AT)"
|
|
|
|
|
|
|
|
|
|
| 40 |
}
|
| 41 |
|
| 42 |
MANUFACTURER_JP_MAP = {
|
|
@@ -47,6 +191,7 @@ MANUFACTURER_JP_MAP = {
|
|
| 47 |
"LS ELECTRIC": "LS ELECTRIC"
|
| 48 |
}
|
| 49 |
|
|
|
|
| 50 |
KNOWN_MANUFACTURERS = {
|
| 51 |
"MITSUBISHI ELECTRIC",
|
| 52 |
"SIEMENS",
|
|
@@ -75,7 +220,7 @@ def crop_with_optional_expand(img, x1, y1, x2, y2, label):
|
|
| 75 |
return img[max(0,y1):min(h,y2), max(0,x1):min(w,x2)]
|
| 76 |
|
| 77 |
def upscale(img):
|
| 78 |
-
return cv2.resize(img, None, fx=
|
| 79 |
|
| 80 |
def rotate_image(img, a):
|
| 81 |
if a == 90: return cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
|
|
@@ -109,6 +254,73 @@ def enhance_AT(img):
|
|
| 109 |
sharp = cv2.addWeighted(img, 1.5, blur, -0.5, 0)
|
| 110 |
return sharp
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
def save_crop(label, img):
|
| 113 |
cv2.imwrite(os.path.join(CROP_DIR, f"{label}.jpg"), img)
|
| 114 |
|
|
@@ -142,19 +354,37 @@ def normalize_for_compare(t):
|
|
| 142 |
|
| 143 |
def gpt_single_ocr(label, img):
|
| 144 |
b64 = img_to_base64(img)
|
|
|
|
| 145 |
rules = {
|
| 146 |
"Manufacture Name": "Read manufacturer name in English only.",
|
| 147 |
"Circuit Name": "Read the FULL text exactly as printed.",
|
| 148 |
"Load Name": "Read exact text.",
|
| 149 |
"AF": "Read the FULL text exactly as printed.",
|
| 150 |
"AT": "Read the FULL text exactly as printed.",
|
| 151 |
-
"Breaking Capacity": "Read the FULL text exactly as printed."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
}
|
| 153 |
|
| 154 |
r = client.chat.completions.create(
|
| 155 |
model="gpt-5.2",
|
| 156 |
messages=[
|
| 157 |
-
{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
{"role":"user","content":[
|
| 159 |
{"type":"text","text":rules[label]},
|
| 160 |
{"type":"image_url","image_url":{"url":f"data:image/jpeg;base64,{b64}"}}
|
|
@@ -162,19 +392,22 @@ def gpt_single_ocr(label, img):
|
|
| 162 |
],
|
| 163 |
temperature=0
|
| 164 |
)
|
|
|
|
| 165 |
return r.choices[0].message.content.strip()
|
| 166 |
|
|
|
|
| 167 |
# ================= OCR CORE =================
|
| 168 |
|
| 169 |
def gpt_ocr(label, img):
|
| 170 |
|
| 171 |
-
# ================= MANUFACTURER ================
|
| 172 |
if label in ["Manufacture Name","Load Name"]:
|
| 173 |
img = enhance(img)
|
| 174 |
save_crop(label, img)
|
| 175 |
t = gpt_single_ocr(label, img)
|
| 176 |
return normalize_manufacturer(t) if label=="Manufacture Name" else remove_spaces_only(t)
|
| 177 |
|
|
|
|
| 178 |
if label == "Breaking Capacity":
|
| 179 |
img = enhance_breaking_capacity(img)
|
| 180 |
t = gpt_single_ocr(label, img)
|
|
@@ -184,14 +417,19 @@ def gpt_ocr(label, img):
|
|
| 184 |
|
| 185 |
|
| 186 |
# ========= ROTATION BASED LABELS =========
|
|
|
|
| 187 |
best_text = ""
|
| 188 |
best_score = -1
|
| 189 |
best_img = None
|
| 190 |
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
|
|
|
|
|
|
| 194 |
|
|
|
|
|
|
|
| 195 |
|
| 196 |
try:
|
| 197 |
t = gpt_single_ocr(label, rimg)
|
|
@@ -248,6 +486,26 @@ def gpt_ocr(label, img):
|
|
| 248 |
continue
|
| 249 |
|
| 250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
# Track best candidate
|
| 252 |
if score > best_score:
|
| 253 |
best_score = score
|
|
@@ -262,7 +520,6 @@ def gpt_ocr(label, img):
|
|
| 262 |
return ""
|
| 263 |
|
| 264 |
|
| 265 |
-
|
| 266 |
# ================= EXCEL VERIFICATION =================
|
| 267 |
|
| 268 |
def normalize_header(s):
|
|
@@ -306,23 +563,13 @@ def verify_excel(excel, det):
|
|
| 306 |
|
| 307 |
if hdr is None:
|
| 308 |
return pd.DataFrame([
|
| 309 |
-
["Excel", "", "エラー", "
|
| 310 |
], columns=["仕様","検出値","Excelに存在?","備考"])
|
| 311 |
|
| 312 |
df = raw.iloc[hdr+1:].copy()
|
| 313 |
df.columns = raw.iloc[hdr]
|
| 314 |
df.dropna(how="all", inplace=True)
|
| 315 |
|
| 316 |
-
def normalize_header(s):
|
| 317 |
-
return str(s).replace("\n","").replace(" ","")
|
| 318 |
-
|
| 319 |
-
def find_column(df, keys):
|
| 320 |
-
for c in df.columns:
|
| 321 |
-
for k in keys:
|
| 322 |
-
if k in normalize_header(c):
|
| 323 |
-
return c
|
| 324 |
-
return None
|
| 325 |
-
|
| 326 |
ccol = find_column(df, ["回路番号","回路"])
|
| 327 |
|
| 328 |
if ccol is None:
|
|
@@ -343,10 +590,26 @@ def verify_excel(excel, det):
|
|
| 343 |
|
| 344 |
rows = []
|
| 345 |
|
|
|
|
| 346 |
for k, jp in SPEC_JP.items():
|
| 347 |
|
| 348 |
detected_value = det.get(k, "").strip()
|
| 349 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
if col is None:
|
| 352 |
rows.append([
|
|
@@ -359,12 +622,19 @@ def verify_excel(excel, det):
|
|
| 359 |
|
| 360 |
excel_value = str(target[col])
|
| 361 |
|
| 362 |
-
if k=="Manufacture Name":
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 367 |
|
|
|
|
| 368 |
if not detected_value:
|
| 369 |
rows.append([
|
| 370 |
jp,
|
|
@@ -386,11 +656,10 @@ def verify_excel(excel, det):
|
|
| 386 |
return pd.DataFrame(rows,columns=["仕様","検出値","Excelに存在?","備考"])
|
| 387 |
|
| 388 |
|
| 389 |
-
|
| 390 |
-
|
| 391 |
# ================= PIPELINE & UI =================
|
| 392 |
|
| 393 |
-
def bbox_area(p):
|
|
|
|
| 394 |
|
| 395 |
def run_pipeline(image, excel):
|
| 396 |
|
|
@@ -402,7 +671,9 @@ def run_pipeline(image, excel):
|
|
| 402 |
return None, pd.DataFrame(), pd.DataFrame(), None, \
|
| 403 |
"⚠️ **Please upload the breaker panel image before running verification.**"
|
| 404 |
|
|
|
|
| 405 |
img = prepare_for_roboflow(image)
|
|
|
|
| 406 |
preds = model.predict(
|
| 407 |
img,
|
| 408 |
confidence=int(CONF_THRESHOLD*100),
|
|
@@ -410,37 +681,68 @@ def run_pipeline(image, excel):
|
|
| 410 |
).json()["predictions"]
|
| 411 |
|
| 412 |
vis = img.copy()
|
| 413 |
-
det={}
|
| 414 |
-
best_boxes={}
|
| 415 |
|
|
|
|
| 416 |
for p in preds:
|
| 417 |
-
|
|
|
|
|
|
|
| 418 |
if lab not in best_boxes:
|
| 419 |
-
best_boxes[lab]=p
|
| 420 |
else:
|
| 421 |
-
if lab=="Circuit Name":
|
| 422 |
-
if bbox_area(p)<bbox_area(best_boxes[lab]):
|
| 423 |
-
best_boxes[lab]=p
|
| 424 |
else:
|
| 425 |
-
if p["confidence"]>best_boxes[lab]["confidence"]:
|
| 426 |
-
best_boxes[lab]=p
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 427 |
|
| 428 |
-
|
| 429 |
-
x,y,w,h=map(int,[p["x"],p["y"],p["width"],p["height"]])
|
| 430 |
-
x1,y1,x2,y2=x-w//2,y-h//2,x+w//2,y+h//2
|
| 431 |
-
cv2.rectangle(vis,(x1,y1),(x2,y2),(0,255,0),2)
|
| 432 |
-
roi = upscale(crop_with_optional_expand(img,x1,y1,x2,y2,lab))
|
| 433 |
-
det[lab]=gpt_ocr(lab,roi)
|
| 434 |
|
| 435 |
-
|
| 436 |
-
|
|
|
|
| 437 |
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
|
| 443 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 444 |
|
| 445 |
apple_dark_pink_css = """
|
| 446 |
@import url('https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;500;600&display=swap');
|
|
@@ -491,34 +793,6 @@ input, textarea, select {
|
|
| 491 |
}
|
| 492 |
"""
|
| 493 |
|
| 494 |
-
import gradio as gr
|
| 495 |
-
|
| 496 |
-
apple_dark_pink_css = """
|
| 497 |
-
@import url('https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;500;600&display=swap');
|
| 498 |
-
.gradio-container {
|
| 499 |
-
background: #0f1115;
|
| 500 |
-
font-family: 'Outfit', sans-serif;
|
| 501 |
-
}
|
| 502 |
-
h1 { color: #f9fafb; font-weight: 600; }
|
| 503 |
-
h2, h3 { color: #e5e7eb; font-weight: 500; }
|
| 504 |
-
.gr-box {
|
| 505 |
-
background: #161a22;
|
| 506 |
-
border-radius: 16px;
|
| 507 |
-
padding: 12px;
|
| 508 |
-
}
|
| 509 |
-
button.primary {
|
| 510 |
-
background: #f472b6 !important;
|
| 511 |
-
color: #020617 !important;
|
| 512 |
-
border-radius: 12px;
|
| 513 |
-
font-weight: 500;
|
| 514 |
-
}
|
| 515 |
-
button.primary:hover {
|
| 516 |
-
background: #ec4899 !important;
|
| 517 |
-
}
|
| 518 |
-
input, textarea, select {
|
| 519 |
-
border-radius: 10px !important;
|
| 520 |
-
}
|
| 521 |
-
"""
|
| 522 |
with gr.Blocks(
|
| 523 |
theme=gr.themes.Soft(primary_hue="pink"),
|
| 524 |
css=apple_dark_pink_css
|
|
@@ -545,4 +819,4 @@ with gr.Blocks(
|
|
| 545 |
[img_out, t1, t2, f, status_msg]
|
| 546 |
)
|
| 547 |
|
| 548 |
-
demo.launch()
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torchvision import models, transforms
|
| 4 |
+
import albumentations as A
|
| 5 |
+
from albumentations.pytorch import ToTensorV2
|
| 6 |
import os, cv2, re, base64
|
| 7 |
import numpy as np
|
| 8 |
import pandas as pd
|
| 9 |
import gradio as gr
|
| 10 |
+
from difflib import get_close_matches
|
| 11 |
from roboflow import Roboflow
|
| 12 |
from openai import OpenAI
|
| 13 |
from openpyxl import load_workbook
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
from PIL import Image
|
| 16 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 17 |
+
import cv2
|
| 18 |
|
| 19 |
# ================= CONFIG =================

# NOTE(review): the Roboflow API key was hard-coded and committed to source
# control. Read it from the environment first; the literal fallback keeps
# existing deployments working, but the key should be rotated and removed.
ROBOFLOW_API_KEY = os.getenv("ROBOFLOW_API_KEY", "uP19IAi98TqwLvHmNB8V")
ROBOFLOW_PROJECT = "braker3"
ROBOFLOW_VERSION = 12
CONF_THRESHOLD = 0.35  # detector confidence threshold (fraction; converted to percent at predict time)
IOU_THRESHOLD = 0.4    # overlap threshold passed to the detector
| 26 |
|
|
|
|
| 29 |
rf = Roboflow(api_key=ROBOFLOW_API_KEY)
|
| 30 |
model = rf.workspace().project(ROBOFLOW_PROJECT).version(ROBOFLOW_VERSION).model
|
| 31 |
|
| 32 |
+
CROP_DIR = "cropped_labels"
|
| 33 |
os.makedirs(CROP_DIR, exist_ok=True)
|
| 34 |
|
| 35 |
+
# ================= CLASSIFIER =================
|
| 36 |
+
|
| 37 |
+
CLASS_NAMES = ['BB', 'FF', 'P']
|
| 38 |
+
|
| 39 |
+
classifier_model = models.efficientnet_b0(weights=None)
|
| 40 |
+
in_features = classifier_model.classifier[1].in_features
|
| 41 |
+
|
| 42 |
+
classifier_model.classifier[1] = nn.Sequential(
|
| 43 |
+
nn.Dropout(p=0.3, inplace=True),
|
| 44 |
+
nn.Linear(in_features, 3)
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
classifier_model.load_state_dict(torch.load('breaker_classifier.pth', map_location='cpu'))
|
| 48 |
+
classifier_model.eval()
|
| 49 |
+
|
| 50 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 51 |
+
classifier_model.to(device)
|
| 52 |
+
|
| 53 |
+
# transform
|
| 54 |
+
type_transform = A.Compose([
|
| 55 |
+
A.Resize(224, 224),
|
| 56 |
+
A.Normalize(mean=(0.485, 0.456, 0.406),
|
| 57 |
+
std=(0.229, 0.224, 0.225)),
|
| 58 |
+
ToTensorV2(),
|
| 59 |
+
])
|
| 60 |
+
|
| 61 |
+
def predict_breaker_type(image):
    """Classify a breaker image into one of CLASS_NAMES ('BB', 'FF', 'P').

    Args:
        image: OpenCV image in BGR channel order.

    Returns:
        (label, confidence): predicted class name and its softmax probability.
    """
    # Albumentations pipeline was trained on RGB input.
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    batch = type_transform(image=rgb)["image"].unsqueeze(0).to(device)

    with torch.no_grad():
        logits = classifier_model(batch)
        probabilities = torch.softmax(logits, dim=1)
        top_prob, top_idx = probabilities.max(dim=1)

    return CLASS_NAMES[top_idx.item()], float(top_prob.item())
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# ================= EQUIPMENT TYPE CLASSIFIER =================
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# ================= DEVICE =================
|
| 80 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 81 |
+
|
| 82 |
+
# ================= CLASS NAMES (MUST MATCH TRAINING) =================
|
| 83 |
+
EQUIPMENT_CLASS_NAMES = ['ACB', 'E', 'EM', 'EMDU', 'EMMDU', 'M', 'MDU']
|
| 84 |
+
|
| 85 |
+
# ================= TRANSFORM (EXACT SAME AS TRAINING) =================
|
| 86 |
+
val_transform = transforms.Compose([
|
| 87 |
+
transforms.Resize((224, 224)),
|
| 88 |
+
transforms.ToTensor(),
|
| 89 |
+
transforms.Normalize([0.485, 0.456, 0.406],
|
| 90 |
+
[0.229, 0.224, 0.225])
|
| 91 |
+
])
|
| 92 |
+
|
| 93 |
+
# ================= LOAD MODEL =================
|
| 94 |
+
equipment_model = models.efficientnet_b0(pretrained=False)
|
| 95 |
+
|
| 96 |
+
# IMPORTANT: same classifier as training
|
| 97 |
+
equipment_model.classifier[1] = nn.Linear(
|
| 98 |
+
equipment_model.classifier[1].in_features,
|
| 99 |
+
len(EQUIPMENT_CLASS_NAMES)
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# Load weights
|
| 103 |
+
equipment_model.load_state_dict(
|
| 104 |
+
torch.load("efficientnet_breaker.pth", map_location=device)
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
equipment_model = equipment_model.to(device)
|
| 108 |
+
equipment_model.eval()
|
| 109 |
+
|
| 110 |
+
# ================= PREDICTION FUNCTION =================
|
| 111 |
+
def predict_equipment_type(image):
    """Classify a panel image into an equipment type.

    Args:
        image: OpenCV image (BGR).

    Returns:
        (best_class, best_conf, prob_dict): the top label, its probability,
        and a dict mapping every EQUIPMENT_CLASS_NAMES entry to its softmax
        probability.
    """
    # torchvision transforms expect a PIL image: convert BGR -> RGB -> PIL.
    pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    # Identical preprocessing to what was used at training time.
    batch = val_transform(pil_img).unsqueeze(0).to(device)

    with torch.no_grad():
        probabilities = F.softmax(equipment_model(batch), dim=1)[0]

    prob_dict = {
        name: float(prob.item())
        for name, prob in zip(EQUIPMENT_CLASS_NAMES, probabilities)
    }

    # max() over insertion order matches argmax's first-max tie-breaking.
    best_class = max(prob_dict, key=prob_dict.get)
    return best_class, prob_dict[best_class], prob_dict
|
| 144 |
+
|
| 145 |
# ================= CONSTANTS =================
|
| 146 |
|
| 147 |
CIRCUIT_PATTERN = r"(?:\d+L\d+-\d+|S\d+)"
|
|
|
|
| 150 |
|
| 151 |
VALID_AF_VALUES = {"50","63","100","125","160","250","400","630"}
|
| 152 |
|
| 153 |
+
VALID_OPTIONS = [
|
| 154 |
+
"AL",
|
| 155 |
+
"AX",
|
| 156 |
+
"PAL",
|
| 157 |
+
"EAL",
|
| 158 |
+
"SHT",
|
| 159 |
+
"AL+AX",
|
| 160 |
+
"AL+AX+PAL+EAL"
|
| 161 |
+
]
|
| 162 |
+
|
| 163 |
+
# ================= LABEL MAPPING (ADD HERE) =================
|
| 164 |
+
LABEL_MAP = {
|
| 165 |
+
"manufacture name": "Manufacture Name",
|
| 166 |
+
"load name": "Load Name",
|
| 167 |
+
"breaking capacity": "Breaking Capacity",
|
| 168 |
+
"af": "AF",
|
| 169 |
+
"at": "AT",
|
| 170 |
+
"option": "Option",
|
| 171 |
+
"circuit name": "Circuit Name"
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
SPEC_JP = {
|
| 175 |
"Manufacture Name": "メーカー",
|
| 176 |
"Circuit Name": "回路番号",
|
| 177 |
"Load Name": "負荷名称",
|
| 178 |
"Breaking Capacity": "遮断容量",
|
| 179 |
"AF": "フレーム(AF)",
|
| 180 |
+
"AT": "トリップ(AT)",
|
| 181 |
+
"Option": "オプション",
|
| 182 |
+
"Type": "タイプ",
|
| 183 |
+
"Equipment Type": "機器種別"
|
| 184 |
}
|
| 185 |
|
| 186 |
MANUFACTURER_JP_MAP = {
|
|
|
|
| 191 |
"LS ELECTRIC": "LS ELECTRIC"
|
| 192 |
}
|
| 193 |
|
| 194 |
+
# ✅ STRICT WHITELIST
|
| 195 |
KNOWN_MANUFACTURERS = {
|
| 196 |
"MITSUBISHI ELECTRIC",
|
| 197 |
"SIEMENS",
|
|
|
|
| 220 |
return img[max(0,y1):min(h,y2), max(0,x1):min(w,x2)]
|
| 221 |
|
| 222 |
def upscale(img):
    """Upscale a cropped label 1.5x with bicubic interpolation to aid OCR."""
    return cv2.resize(img, None, fx=1.5, fy=1.5, interpolation=cv2.INTER_CUBIC)
|
| 224 |
|
| 225 |
def rotate_image(img, a):
|
| 226 |
if a == 90: return cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
|
|
|
|
| 254 |
sharp = cv2.addWeighted(img, 1.5, blur, -0.5, 0)
|
| 255 |
return sharp
|
| 256 |
|
| 257 |
+
def enhance_option(img):
    """Aggressively enhance a tiny, dark OPTION label crop for OCR.

    Pipeline: 2.2x bicubic upscale -> HSV brightness/saturation boost ->
    gamma correction -> grayscale denoise -> CLAHE -> unsharp mask.
    Returns a 3-channel BGR image (grayscale replicated across channels).
    """
    # Enlarge first so tiny glyphs survive the later filtering steps.
    img = cv2.resize(img, None, fx=2.2, fy=2.2, interpolation=cv2.INTER_CUBIC)

    # Lift brightness and saturation in HSV space so faded letters on dark
    # plastic labels become visible.
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    hsv[:, :, 2] = cv2.add(hsv[:, :, 2], 55)  # V: brighten dark labels
    hsv[:, :, 1] = cv2.add(hsv[:, :, 1], 35)  # S: recover faded color
    img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

    # Gamma correction via LUT — critical for dark plastic labels.
    gamma = 1.6
    inv_gamma = 1.0 / gamma
    lut = np.array([(i / 255.0) ** inv_gamma * 255 for i in range(256)]).astype("uint8")
    img = cv2.LUT(img, lut)

    # Grayscale + denoise before the contrast work.
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray = cv2.fastNlMeansDenoising(gray, h=10)

    # Strong local contrast, then an unsharp mask to pull out dark letters.
    gray = cv2.createCLAHE(4.0, (8, 8)).apply(gray)
    blurred = cv2.GaussianBlur(gray, (0, 0), 1.3)
    sharpened = cv2.addWeighted(gray, 1.9, blurred, -0.9, 0)

    return cv2.cvtColor(sharpened, cv2.COLOR_GRAY2BGR)
|
| 293 |
+
|
| 294 |
+
def normalize_option_text(text):
    """Snap raw OCR output for an OPTION label onto a canonical option code.

    Steps: uppercase/strip -> per-character confusion-map repair (common
    OCR and upside-down misreads) -> exact match against VALID_OPTIONS ->
    fuzzy match fallback. Returns "" for empty input, otherwise the matched
    option or the repaired text when nothing matches.
    """
    if not text:
        return ""

    cleaned = text.upper().strip()

    # Characters the OCR commonly misreads on rotated/dark labels.
    confusions = {
        "7": "A",  # upside-down A
        "V": "L",  # upside-down L
        "1": "L",
        "|": "L",
        "I": "L",
        "Y": "L",
        "4": "A",
    }
    repaired = "".join(confusions.get(ch, ch) for ch in cleaned)

    # Exact hit against the whitelist wins outright.
    if repaired in VALID_OPTIONS:
        return repaired

    # Fuzzy fallback: snap near-misses onto the closest known option.
    matches = get_close_matches(repaired, VALID_OPTIONS, n=1, cutoff=0.4)
    return matches[0] if matches else repaired
|
| 323 |
+
|
| 324 |
def save_crop(label, img):
    """Write a label crop to CROP_DIR as '<label>.jpg' (debug/inspection aid)."""
    cv2.imwrite(os.path.join(CROP_DIR, f"{label}.jpg"), img)
|
| 326 |
|
|
|
|
| 354 |
|
| 355 |
def gpt_single_ocr(label, img):
|
| 356 |
b64 = img_to_base64(img)
|
| 357 |
+
|
| 358 |
rules = {
|
| 359 |
"Manufacture Name": "Read manufacturer name in English only.",
|
| 360 |
"Circuit Name": "Read the FULL text exactly as printed.",
|
| 361 |
"Load Name": "Read exact text.",
|
| 362 |
"AF": "Read the FULL text exactly as printed.",
|
| 363 |
"AT": "Read the FULL text exactly as printed.",
|
| 364 |
+
"Breaking Capacity": "Read the FULL text exactly as printed.",
|
| 365 |
+
|
| 366 |
+
# ⭐ NEW — OPTION PROMPT
|
| 367 |
+
"Option": (
|
| 368 |
+
"Industrial breaker OPTION label. "
|
| 369 |
+
"Text is short (1–10 characters). "
|
| 370 |
+
"Image may be dark, faint, small, or rotated. "
|
| 371 |
+
"Return ONLY the exact printed text. "
|
| 372 |
+
"If unreadable return empty."
|
| 373 |
+
)
|
| 374 |
}
|
| 375 |
|
| 376 |
r = client.chat.completions.create(
|
| 377 |
model="gpt-5.2",
|
| 378 |
messages=[
|
| 379 |
+
{
|
| 380 |
+
"role": "system",
|
| 381 |
+
"content": (
|
| 382 |
+
"You are an industrial electrical label OCR engine. "
|
| 383 |
+
"Extract text exactly as printed. "
|
| 384 |
+
"Text may be tiny, dark, or rotated. "
|
| 385 |
+
"Return only text."
|
| 386 |
+
)
|
| 387 |
+
},
|
| 388 |
{"role":"user","content":[
|
| 389 |
{"type":"text","text":rules[label]},
|
| 390 |
{"type":"image_url","image_url":{"url":f"data:image/jpeg;base64,{b64}"}}
|
|
|
|
| 392 |
],
|
| 393 |
temperature=0
|
| 394 |
)
|
| 395 |
+
|
| 396 |
return r.choices[0].message.content.strip()
|
| 397 |
|
| 398 |
+
|
| 399 |
# ================= OCR CORE =================
|
| 400 |
|
| 401 |
def gpt_ocr(label, img):
|
| 402 |
|
| 403 |
+
# ================= MANUFACTURER =================
|
| 404 |
if label in ["Manufacture Name","Load Name"]:
|
| 405 |
img = enhance(img)
|
| 406 |
save_crop(label, img)
|
| 407 |
t = gpt_single_ocr(label, img)
|
| 408 |
return normalize_manufacturer(t) if label=="Manufacture Name" else remove_spaces_only(t)
|
| 409 |
|
| 410 |
+
# ================= BREAKING CAPACITY =================
|
| 411 |
if label == "Breaking Capacity":
|
| 412 |
img = enhance_breaking_capacity(img)
|
| 413 |
t = gpt_single_ocr(label, img)
|
|
|
|
| 417 |
|
| 418 |
|
| 419 |
# ========= ROTATION BASED LABELS =========
|
| 420 |
+
|
| 421 |
best_text = ""
|
| 422 |
best_score = -1
|
| 423 |
best_img = None
|
| 424 |
|
| 425 |
+
# ⭐ IMPORTANT — choose preprocessing based on label
|
| 426 |
+
if label == "Option":
|
| 427 |
+
base = enhance_option(img)
|
| 428 |
+
else:
|
| 429 |
+
base = enhance(img)
|
| 430 |
|
| 431 |
+
for a in [0, 90, 180, 270]:
|
| 432 |
+
rimg = rotate_image(base, a)
|
| 433 |
|
| 434 |
try:
|
| 435 |
t = gpt_single_ocr(label, rimg)
|
|
|
|
| 486 |
continue
|
| 487 |
|
| 488 |
|
| 489 |
+
# ================= OPTION ⭐ NEW =================
|
| 490 |
+
elif label == "Option":
|
| 491 |
+
|
| 492 |
+
# ⭐ IMPORTANT — normalize BEFORE scoring
|
| 493 |
+
candidate = normalize_option_text(clean)
|
| 494 |
+
|
| 495 |
+
# Option text usually short
|
| 496 |
+
if len(candidate) <= 10:
|
| 497 |
+
score += 40
|
| 498 |
+
|
| 499 |
+
# prefer valid option hits (strong signal)
|
| 500 |
+
if candidate in VALID_OPTIONS:
|
| 501 |
+
score += 120
|
| 502 |
+
|
| 503 |
+
# prefer alphabetic (real options are alphabetic)
|
| 504 |
+
if re.search(r"[A-Za-z]", candidate):
|
| 505 |
+
score += 20
|
| 506 |
+
|
| 507 |
+
score += len(candidate)
|
| 508 |
+
|
| 509 |
# Track best candidate
|
| 510 |
if score > best_score:
|
| 511 |
best_score = score
|
|
|
|
| 520 |
return ""
|
| 521 |
|
| 522 |
|
|
|
|
| 523 |
# ================= EXCEL VERIFICATION =================
|
| 524 |
|
| 525 |
def normalize_header(s):
|
|
|
|
| 563 |
|
| 564 |
if hdr is None:
|
| 565 |
return pd.DataFrame([
|
| 566 |
+
["Excel", "", "エラー", "ヘッダー行が見つかりません。"]
|
| 567 |
], columns=["仕様","検出値","Excelに存在?","備考"])
|
| 568 |
|
| 569 |
df = raw.iloc[hdr+1:].copy()
|
| 570 |
df.columns = raw.iloc[hdr]
|
| 571 |
df.dropna(how="all", inplace=True)
|
| 572 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
ccol = find_column(df, ["回路番号","回路"])
|
| 574 |
|
| 575 |
if ccol is None:
|
|
|
|
| 590 |
|
| 591 |
rows = []
|
| 592 |
|
| 593 |
+
# ⭐ LOOP THROUGH ALL FIELDS (NOW INCLUDES TYPE)
|
| 594 |
for k, jp in SPEC_JP.items():
|
| 595 |
|
| 596 |
detected_value = det.get(k, "").strip()
|
| 597 |
+
|
| 598 |
+
# ---- column search keys ----
|
| 599 |
+
keys = [jp.replace("(","").replace(")",""), jp[:2]]
|
| 600 |
+
|
| 601 |
+
# ⭐ OPTION SUPPORT
|
| 602 |
+
if k == "Option":
|
| 603 |
+
keys += ["オプション", "オプシ", "ション"]
|
| 604 |
+
|
| 605 |
+
# ⭐ TYPE SUPPORT (NEW 🔥)
|
| 606 |
+
if k == "Type":
|
| 607 |
+
keys += ["タイプ"]
|
| 608 |
+
|
| 609 |
+
if k == "Equipment Type":
|
| 610 |
+
keys += ["機器 種別"]
|
| 611 |
+
|
| 612 |
+
col = find_column(df, keys)
|
| 613 |
|
| 614 |
if col is None:
|
| 615 |
rows.append([
|
|
|
|
| 622 |
|
| 623 |
excel_value = str(target[col])
|
| 624 |
|
| 625 |
+
if k == "Manufacture Name":
|
| 626 |
+
detected_jp = MANUFACTURER_JP_MAP.get(detected_value, detected_value)
|
| 627 |
+
excel_jp = excel_value
|
| 628 |
+
ok = normalize_for_compare(detected_jp) == normalize_for_compare(excel_jp)
|
| 629 |
+
rows.append([
|
| 630 |
+
jp,
|
| 631 |
+
detected_jp, # 👈 show Japanese here
|
| 632 |
+
"YES" if ok else "NO",
|
| 633 |
+
"" if ok else f"Excel値: {excel_jp}"
|
| 634 |
+
])
|
| 635 |
+
continue
|
| 636 |
|
| 637 |
+
# not detected in panel
|
| 638 |
if not detected_value:
|
| 639 |
rows.append([
|
| 640 |
jp,
|
|
|
|
| 656 |
return pd.DataFrame(rows,columns=["仕様","検出値","Excelに存在?","備考"])
|
| 657 |
|
| 658 |
|
|
|
|
|
|
|
| 659 |
# ================= PIPELINE & UI =================
|
| 660 |
|
| 661 |
+
def bbox_area(p):
    """Area (width x height) of a detector prediction's bounding box."""
    return p["height"] * p["width"]
|
| 663 |
|
| 664 |
def run_pipeline(image, excel):
|
| 665 |
|
|
|
|
| 671 |
return None, pd.DataFrame(), pd.DataFrame(), None, \
|
| 672 |
"⚠️ **Please upload the breaker panel image before running verification.**"
|
| 673 |
|
| 674 |
+
# ================= DETECTION =================
|
| 675 |
img = prepare_for_roboflow(image)
|
| 676 |
+
|
| 677 |
preds = model.predict(
|
| 678 |
img,
|
| 679 |
confidence=int(CONF_THRESHOLD*100),
|
|
|
|
| 681 |
).json()["predictions"]
|
| 682 |
|
| 683 |
vis = img.copy()
|
| 684 |
+
det = {}
|
| 685 |
+
best_boxes = {}
|
| 686 |
|
| 687 |
+
# ================= SELECT BEST BOX =================
|
| 688 |
for p in preds:
|
| 689 |
+
raw_lab = p["class"]
|
| 690 |
+
lab = LABEL_MAP.get(raw_lab.lower(), raw_lab)
|
| 691 |
+
|
| 692 |
if lab not in best_boxes:
|
| 693 |
+
best_boxes[lab] = p
|
| 694 |
else:
|
| 695 |
+
if lab == "Circuit Name":
|
| 696 |
+
if bbox_area(p) < bbox_area(best_boxes[lab]):
|
| 697 |
+
best_boxes[lab] = p
|
| 698 |
else:
|
| 699 |
+
if p["confidence"] > best_boxes[lab]["confidence"]:
|
| 700 |
+
best_boxes[lab] = p
|
| 701 |
+
|
| 702 |
+
# ================= PARALLEL OCR =================
|
| 703 |
+
def process_label(item):
|
| 704 |
+
lab, p = item
|
| 705 |
+
|
| 706 |
+
x, y, w, h = map(int, [p["x"], p["y"], p["width"], p["height"]])
|
| 707 |
+
x1, y1, x2, y2 = x - w//2, y - h//2, x + w//2, y + h//2
|
| 708 |
+
|
| 709 |
+
roi = upscale(crop_with_optional_expand(img, x1, y1, x2, y2, lab))
|
| 710 |
+
|
| 711 |
+
try:
|
| 712 |
+
value = gpt_ocr(lab, roi)
|
| 713 |
+
except Exception as e:
|
| 714 |
+
value = ""
|
| 715 |
+
print(f"OCR Error for {lab}: {e}")
|
| 716 |
|
| 717 |
+
return lab, value, (x1, y1, x2, y2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 718 |
|
| 719 |
+
# 🔥 Parallel execution (IMPORTANT)
|
| 720 |
+
with ThreadPoolExecutor(max_workers=5) as executor:
|
| 721 |
+
results = list(executor.map(process_label, best_boxes.items()))
|
| 722 |
|
| 723 |
+
# ================= COLLECT RESULTS =================
|
| 724 |
+
for lab, value, (x1, y1, x2, y2) in results:
|
| 725 |
+
det[lab] = value
|
| 726 |
+
cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
| 727 |
|
| 728 |
+
# ================= TYPE CLASSIFICATION =================
|
| 729 |
+
pred_type, conf = predict_breaker_type(img)
|
| 730 |
+
det["Type"] = pred_type
|
| 731 |
+
|
| 732 |
+
eq_type, eq_conf, _ = predict_equipment_type(img)
|
| 733 |
+
det["Equipment Type"] = eq_type
|
| 734 |
+
|
| 735 |
+
# ================= DATAFRAMES =================
|
| 736 |
+
ocr_df = pd.DataFrame(det.items(), columns=["Field", "Extracted Text"])
|
| 737 |
+
verify_df = verify_excel(excel, det)
|
| 738 |
+
|
| 739 |
+
# ================= SAVE =================
|
| 740 |
+
out = "verification_result.xlsx"
|
| 741 |
+
with pd.ExcelWriter(out, engine="openpyxl") as w:
|
| 742 |
+
ocr_df.to_excel(w, "OCR_Output", index=False)
|
| 743 |
+
verify_df.to_excel(w, "Verification", index=False)
|
| 744 |
+
|
| 745 |
+
return vis, ocr_df, verify_df, out, ""
|
| 746 |
|
| 747 |
apple_dark_pink_css = """
|
| 748 |
@import url('https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;500;600&display=swap');
|
|
|
|
| 793 |
}
|
| 794 |
"""
|
| 795 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 796 |
with gr.Blocks(
|
| 797 |
theme=gr.themes.Soft(primary_hue="pink"),
|
| 798 |
css=apple_dark_pink_css
|
|
|
|
| 819 |
[img_out, t1, t2, f, status_msg]
|
| 820 |
)
|
| 821 |
|
| 822 |
+
demo.launch()
|