Swaroop05 committed on
Commit
96f4ea7
·
verified ·
1 Parent(s): 699e9c3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +356 -82
app.py CHANGED
@@ -1,16 +1,26 @@
 
 
 
 
 
1
  import os, cv2, re, base64
2
  import numpy as np
3
  import pandas as pd
4
  import gradio as gr
 
5
  from roboflow import Roboflow
6
  from openai import OpenAI
7
  from openpyxl import load_workbook
 
 
 
 
8
 
9
  # ================= CONFIG =================
10
 
11
  ROBOFLOW_API_KEY = "uP19IAi98TqwLvHmNB8V"
12
  ROBOFLOW_PROJECT = "braker3"
13
- ROBOFLOW_VERSION = 10
14
  CONF_THRESHOLD = 0.35
15
  IOU_THRESHOLD = 0.4
16
 
@@ -19,9 +29,119 @@ client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
19
  rf = Roboflow(api_key=ROBOFLOW_API_KEY)
20
  model = rf.workspace().project(ROBOFLOW_PROJECT).version(ROBOFLOW_VERSION).model
21
 
22
- CROP_DIR = "/data/cropped_labels"
23
  os.makedirs(CROP_DIR, exist_ok=True)
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  # ================= CONSTANTS =================
26
 
27
  CIRCUIT_PATTERN = r"(?:\d+L\d+-\d+|S\d+)"
@@ -30,13 +150,37 @@ DEFAULT_BREAKING_CAPACITY = "85"
30
 
31
  VALID_AF_VALUES = {"50","63","100","125","160","250","400","630"}
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  SPEC_JP = {
34
  "Manufacture Name": "メーカー",
35
  "Circuit Name": "回路番号",
36
  "Load Name": "負荷名称",
37
  "Breaking Capacity": "遮断容量",
38
  "AF": "フレーム(AF)",
39
- "AT": "トリップ(AT)"
 
 
 
40
  }
41
 
42
  MANUFACTURER_JP_MAP = {
@@ -47,6 +191,7 @@ MANUFACTURER_JP_MAP = {
47
  "LS ELECTRIC": "LS ELECTRIC"
48
  }
49
 
 
50
  KNOWN_MANUFACTURERS = {
51
  "MITSUBISHI ELECTRIC",
52
  "SIEMENS",
@@ -75,7 +220,7 @@ def crop_with_optional_expand(img, x1, y1, x2, y2, label):
75
  return img[max(0,y1):min(h,y2), max(0,x1):min(w,x2)]
76
 
77
  def upscale(img):
78
- return cv2.resize(img, None, fx=2.5, fy=2.5, interpolation=cv2.INTER_CUBIC)
79
 
80
  def rotate_image(img, a):
81
  if a == 90: return cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
@@ -109,6 +254,73 @@ def enhance_AT(img):
109
  sharp = cv2.addWeighted(img, 1.5, blur, -0.5, 0)
110
  return sharp
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  def save_crop(label, img):
113
  cv2.imwrite(os.path.join(CROP_DIR, f"{label}.jpg"), img)
114
 
@@ -142,19 +354,37 @@ def normalize_for_compare(t):
142
 
143
  def gpt_single_ocr(label, img):
144
  b64 = img_to_base64(img)
 
145
  rules = {
146
  "Manufacture Name": "Read manufacturer name in English only.",
147
  "Circuit Name": "Read the FULL text exactly as printed.",
148
  "Load Name": "Read exact text.",
149
  "AF": "Read the FULL text exactly as printed.",
150
  "AT": "Read the FULL text exactly as printed.",
151
- "Breaking Capacity": "Read the FULL text exactly as printed."
 
 
 
 
 
 
 
 
 
152
  }
153
 
154
  r = client.chat.completions.create(
155
  model="gpt-5.2",
156
  messages=[
157
- {"role":"system","content":"You are a strict OCR engine."},
 
 
 
 
 
 
 
 
158
  {"role":"user","content":[
159
  {"type":"text","text":rules[label]},
160
  {"type":"image_url","image_url":{"url":f"data:image/jpeg;base64,{b64}"}}
@@ -162,19 +392,22 @@ def gpt_single_ocr(label, img):
162
  ],
163
  temperature=0
164
  )
 
165
  return r.choices[0].message.content.strip()
166
 
 
167
  # ================= OCR CORE =================
168
 
169
  def gpt_ocr(label, img):
170
 
171
- # ================= MANUFACTURER ================
172
  if label in ["Manufacture Name","Load Name"]:
173
  img = enhance(img)
174
  save_crop(label, img)
175
  t = gpt_single_ocr(label, img)
176
  return normalize_manufacturer(t) if label=="Manufacture Name" else remove_spaces_only(t)
177
 
 
178
  if label == "Breaking Capacity":
179
  img = enhance_breaking_capacity(img)
180
  t = gpt_single_ocr(label, img)
@@ -184,14 +417,19 @@ def gpt_ocr(label, img):
184
 
185
 
186
  # ========= ROTATION BASED LABELS =========
 
187
  best_text = ""
188
  best_score = -1
189
  best_img = None
190
 
191
- base = enhance(img)
192
- for a in [0, 90]:
193
- rimg = rotate_image(base, a)
 
 
194
 
 
 
195
 
196
  try:
197
  t = gpt_single_ocr(label, rimg)
@@ -248,6 +486,26 @@ def gpt_ocr(label, img):
248
  continue
249
 
250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  # Track best candidate
252
  if score > best_score:
253
  best_score = score
@@ -262,7 +520,6 @@ def gpt_ocr(label, img):
262
  return ""
263
 
264
 
265
-
266
  # ================= EXCEL VERIFICATION =================
267
 
268
  def normalize_header(s):
@@ -306,23 +563,13 @@ def verify_excel(excel, det):
306
 
307
  if hdr is None:
308
  return pd.DataFrame([
309
- ["Excel", "", "エラー", "項目が見つかりません。"]
310
  ], columns=["仕様","検出値","Excelに存在?","備考"])
311
 
312
  df = raw.iloc[hdr+1:].copy()
313
  df.columns = raw.iloc[hdr]
314
  df.dropna(how="all", inplace=True)
315
 
316
- def normalize_header(s):
317
- return str(s).replace("\n","").replace(" ","")
318
-
319
- def find_column(df, keys):
320
- for c in df.columns:
321
- for k in keys:
322
- if k in normalize_header(c):
323
- return c
324
- return None
325
-
326
  ccol = find_column(df, ["回路番号","回路"])
327
 
328
  if ccol is None:
@@ -343,10 +590,26 @@ def verify_excel(excel, det):
343
 
344
  rows = []
345
 
 
346
  for k, jp in SPEC_JP.items():
347
 
348
  detected_value = det.get(k, "").strip()
349
- col = find_column(df, [jp.replace("(","").replace(")",""), jp[:2]])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
 
351
  if col is None:
352
  rows.append([
@@ -359,12 +622,19 @@ def verify_excel(excel, det):
359
 
360
  excel_value = str(target[col])
361
 
362
- if k=="Manufacture Name":
363
- for eng,jpn in MANUFACTURER_JP_MAP.items():
364
- if jpn in excel_value:
365
- excel_value = eng
366
- break
 
 
 
 
 
 
367
 
 
368
  if not detected_value:
369
  rows.append([
370
  jp,
@@ -386,11 +656,10 @@ def verify_excel(excel, det):
386
  return pd.DataFrame(rows,columns=["仕様","検出値","Excelに存在?","備考"])
387
 
388
 
389
-
390
-
391
  # ================= PIPELINE & UI =================
392
 
393
- def bbox_area(p): return p["width"] * p["height"]
 
394
 
395
  def run_pipeline(image, excel):
396
 
@@ -402,7 +671,9 @@ def run_pipeline(image, excel):
402
  return None, pd.DataFrame(), pd.DataFrame(), None, \
403
  "⚠️ **Please upload the breaker panel image before running verification.**"
404
 
 
405
  img = prepare_for_roboflow(image)
 
406
  preds = model.predict(
407
  img,
408
  confidence=int(CONF_THRESHOLD*100),
@@ -410,37 +681,68 @@ def run_pipeline(image, excel):
410
  ).json()["predictions"]
411
 
412
  vis = img.copy()
413
- det={}
414
- best_boxes={}
415
 
 
416
  for p in preds:
417
- lab=p["class"]
 
 
418
  if lab not in best_boxes:
419
- best_boxes[lab]=p
420
  else:
421
- if lab=="Circuit Name":
422
- if bbox_area(p)<bbox_area(best_boxes[lab]):
423
- best_boxes[lab]=p
424
  else:
425
- if p["confidence"]>best_boxes[lab]["confidence"]:
426
- best_boxes[lab]=p
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
 
428
- for lab,p in best_boxes.items():
429
- x,y,w,h=map(int,[p["x"],p["y"],p["width"],p["height"]])
430
- x1,y1,x2,y2=x-w//2,y-h//2,x+w//2,y+h//2
431
- cv2.rectangle(vis,(x1,y1),(x2,y2),(0,255,0),2)
432
- roi = upscale(crop_with_optional_expand(img,x1,y1,x2,y2,lab))
433
- det[lab]=gpt_ocr(lab,roi)
434
 
435
- ocr_df=pd.DataFrame(det.items(),columns=["Field","Extracted Text"])
436
- verify_df=verify_excel(excel,det)
 
437
 
438
- out="verification_result.xlsx"
439
- with pd.ExcelWriter(out,engine="openpyxl") as w:
440
- ocr_df.to_excel(w,"OCR_Output",index=False)
441
- verify_df.to_excel(w,"Verification",index=False)
442
 
443
- return vis,ocr_df,verify_df,out, ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
444
 
445
  apple_dark_pink_css = """
446
  @import url('https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;500;600&display=swap');
@@ -491,34 +793,6 @@ input, textarea, select {
491
  }
492
  """
493
 
494
- import gradio as gr
495
-
496
- apple_dark_pink_css = """
497
- @import url('https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;500;600&display=swap');
498
- .gradio-container {
499
- background: #0f1115;
500
- font-family: 'Outfit', sans-serif;
501
- }
502
- h1 { color: #f9fafb; font-weight: 600; }
503
- h2, h3 { color: #e5e7eb; font-weight: 500; }
504
- .gr-box {
505
- background: #161a22;
506
- border-radius: 16px;
507
- padding: 12px;
508
- }
509
- button.primary {
510
- background: #f472b6 !important;
511
- color: #020617 !important;
512
- border-radius: 12px;
513
- font-weight: 500;
514
- }
515
- button.primary:hover {
516
- background: #ec4899 !important;
517
- }
518
- input, textarea, select {
519
- border-radius: 10px !important;
520
- }
521
- """
522
  with gr.Blocks(
523
  theme=gr.themes.Soft(primary_hue="pink"),
524
  css=apple_dark_pink_css
@@ -545,4 +819,4 @@ with gr.Blocks(
545
  [img_out, t1, t2, f, status_msg]
546
  )
547
 
548
- demo.launch()
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models, transforms
4
+ import albumentations as A
5
+ from albumentations.pytorch import ToTensorV2
6
  import os, cv2, re, base64
7
  import numpy as np
8
  import pandas as pd
9
  import gradio as gr
10
+ from difflib import get_close_matches
11
  from roboflow import Roboflow
12
  from openai import OpenAI
13
  from openpyxl import load_workbook
14
+ import torch.nn.functional as F
15
+ from PIL import Image
16
+ from concurrent.futures import ThreadPoolExecutor
17
+ import cv2
18
 
19
  # ================= CONFIG =================
20
 
21
  ROBOFLOW_API_KEY = "uP19IAi98TqwLvHmNB8V"
22
  ROBOFLOW_PROJECT = "braker3"
23
+ ROBOFLOW_VERSION = 12
24
  CONF_THRESHOLD = 0.35
25
  IOU_THRESHOLD = 0.4
26
 
 
29
  rf = Roboflow(api_key=ROBOFLOW_API_KEY)
30
  model = rf.workspace().project(ROBOFLOW_PROJECT).version(ROBOFLOW_VERSION).model
31
 
32
+ CROP_DIR = "cropped_labels"
33
  os.makedirs(CROP_DIR, exist_ok=True)
34
 
35
+ # ================= CLASSIFIER =================
36
+
37
+ CLASS_NAMES = ['BB', 'FF', 'P']
38
+
39
+ classifier_model = models.efficientnet_b0(weights=None)
40
+ in_features = classifier_model.classifier[1].in_features
41
+
42
+ classifier_model.classifier[1] = nn.Sequential(
43
+ nn.Dropout(p=0.3, inplace=True),
44
+ nn.Linear(in_features, 3)
45
+ )
46
+
47
+ classifier_model.load_state_dict(torch.load('breaker_classifier.pth', map_location='cpu'))
48
+ classifier_model.eval()
49
+
50
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
51
+ classifier_model.to(device)
52
+
53
+ # transform
54
+ type_transform = A.Compose([
55
+ A.Resize(224, 224),
56
+ A.Normalize(mean=(0.485, 0.456, 0.406),
57
+ std=(0.229, 0.224, 0.225)),
58
+ ToTensorV2(),
59
+ ])
60
+
61
def predict_breaker_type(image):
    """Classify the breaker type of a BGR crop.

    Args:
        image: OpenCV image (BGR, uint8).

    Returns:
        (label, confidence) where label is one of CLASS_NAMES
        ('BB', 'FF', 'P') and confidence is the softmax probability.
    """
    # OpenCV delivers BGR; the classifier was trained on RGB input.
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Resize + normalize + tensor conversion via the albumentations pipeline.
    batch = type_transform(image=rgb)["image"].unsqueeze(0).to(device)

    with torch.no_grad():
        logits = classifier_model(batch)
        probabilities = torch.softmax(logits, dim=1)
        confidence, index = probabilities.max(dim=1)

    return CLASS_NAMES[index.item()], float(confidence.item())
73
+
74
+
75
+ # ================= EQUIPMENT TYPE CLASSIFIER =================
76
+
77
+
78
+
79
+ # ================= DEVICE =================
80
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
81
+
82
+ # ================= CLASS NAMES (MUST MATCH TRAINING) =================
83
+ EQUIPMENT_CLASS_NAMES = ['ACB', 'E', 'EM', 'EMDU', 'EMMDU', 'M', 'MDU']
84
+
85
+ # ================= TRANSFORM (EXACT SAME AS TRAINING) =================
86
+ val_transform = transforms.Compose([
87
+ transforms.Resize((224, 224)),
88
+ transforms.ToTensor(),
89
+ transforms.Normalize([0.485, 0.456, 0.406],
90
+ [0.229, 0.224, 0.225])
91
+ ])
92
+
93
+ # ================= LOAD MODEL =================
94
+ equipment_model = models.efficientnet_b0(pretrained=False)
95
+
96
+ # IMPORTANT: same classifier as training
97
+ equipment_model.classifier[1] = nn.Linear(
98
+ equipment_model.classifier[1].in_features,
99
+ len(EQUIPMENT_CLASS_NAMES)
100
+ )
101
+
102
+ # Load weights
103
+ equipment_model.load_state_dict(
104
+ torch.load("efficientnet_breaker.pth", map_location=device)
105
+ )
106
+
107
+ equipment_model = equipment_model.to(device)
108
+ equipment_model.eval()
109
+
110
+ # ================= PREDICTION FUNCTION =================
111
def predict_equipment_type(image):
    """Classify the equipment type of a BGR OpenCV image.

    Args:
        image: OpenCV image (BGR, uint8).

    Returns:
        best_class: predicted label from EQUIPMENT_CLASS_NAMES
        best_conf:  softmax probability of that label
        prob_dict:  {class_name: probability} over all classes
    """
    # OpenCV (BGR ndarray) -> PIL (RGB); the torchvision transform
    # expects a PIL image, so this conversion is required.
    pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    # Apply exactly the same preprocessing used during training.
    batch = val_transform(pil_img).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = equipment_model(batch)
        probs = F.softmax(logits, dim=1)

    # Full probability distribution, keyed by class name.
    prob_dict = {
        name: float(probs[0][i].item())
        for i, name in enumerate(EQUIPMENT_CLASS_NAMES)
    }

    # Winning class and its confidence.
    pred_idx = torch.argmax(probs, dim=1).item()
    best_class = EQUIPMENT_CLASS_NAMES[pred_idx]
    return best_class, prob_dict[best_class], prob_dict
144
+
145
  # ================= CONSTANTS =================
146
 
147
  CIRCUIT_PATTERN = r"(?:\d+L\d+-\d+|S\d+)"
 
150
 
151
  VALID_AF_VALUES = {"50","63","100","125","160","250","400","630"}
152
 
153
+ VALID_OPTIONS = [
154
+ "AL",
155
+ "AX",
156
+ "PAL",
157
+ "EAL",
158
+ "SHT",
159
+ "AL+AX",
160
+ "AL+AX+PAL+EAL"
161
+ ]
162
+
163
+ # ================= LABEL MAPPING (ADD HERE) =================
164
+ LABEL_MAP = {
165
+ "manufacture name": "Manufacture Name",
166
+ "load name": "Load Name",
167
+ "breaking capacity": "Breaking Capacity",
168
+ "af": "AF",
169
+ "at": "AT",
170
+ "option": "Option",
171
+ "circuit name": "Circuit Name"
172
+ }
173
+
174
  SPEC_JP = {
175
  "Manufacture Name": "メーカー",
176
  "Circuit Name": "回路番号",
177
  "Load Name": "負荷名称",
178
  "Breaking Capacity": "遮断容量",
179
  "AF": "フレーム(AF)",
180
+ "AT": "トリップ(AT)",
181
+ "Option": "オプション",
182
+ "Type": "タイプ",
183
+ "Equipment Type": "機器種別"
184
  }
185
 
186
  MANUFACTURER_JP_MAP = {
 
191
  "LS ELECTRIC": "LS ELECTRIC"
192
  }
193
 
194
+ # ✅ STRICT WHITELIST
195
  KNOWN_MANUFACTURERS = {
196
  "MITSUBISHI ELECTRIC",
197
  "SIEMENS",
 
220
  return img[max(0,y1):min(h,y2), max(0,x1):min(w,x2)]
221
 
222
  def upscale(img):
223
+ return cv2.resize(img, None, fx=1.5, fy=1.5, interpolation=cv2.INTER_CUBIC)
224
 
225
  def rotate_image(img, a):
226
  if a == 90: return cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
 
254
  sharp = cv2.addWeighted(img, 1.5, blur, -0.5, 0)
255
  return sharp
256
 
257
def enhance_option(img):
    """Preprocess a tiny OPTION-label crop for OCR.

    Option labels are typically small, dark and low-contrast, so this
    pipeline upscales, brightens, gamma-corrects, denoises and sharpens
    the crop before it is sent to the OCR model.

    Args:
        img: BGR crop of the option label.

    Returns:
        3-channel (BGR) enhanced image.
    """
    # Upscale so tiny characters survive the later filtering steps.
    img = cv2.resize(img, None, fx=2.2, fy=2.2, interpolation=cv2.INTER_CUBIC)

    # Brighten and saturate in HSV space to lift faded printing.
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    hsv[:, :, 2] = cv2.add(hsv[:, :, 2], 55)  # value channel: brighten dark labels
    hsv[:, :, 1] = cv2.add(hsv[:, :, 1], 35)  # saturation: make faded letters appear
    img = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)

    # Gamma correction (gamma = 1.6) — lightens dark plastic labels.
    lut = np.array(
        [((i / 255.0) ** (1.0 / 1.6)) * 255 for i in range(256)]
    ).astype("uint8")
    img = cv2.LUT(img, lut)

    # Grayscale, denoise, then boost local contrast with CLAHE.
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray = cv2.fastNlMeansDenoising(gray, h=10)
    gray = cv2.createCLAHE(4.0, (8, 8)).apply(gray)

    # Unsharp mask: aggressive sharpening reveals dark letters.
    blurred = cv2.GaussianBlur(gray, (0, 0), 1.3)
    sharpened = cv2.addWeighted(gray, 1.9, blurred, -0.9, 0)

    # Downstream OCR expects a 3-channel image.
    return cv2.cvtColor(sharpened, cv2.COLOR_GRAY2BGR)
293
+
294
def normalize_option_text(text):
    """Snap raw OCR output for an OPTION label onto a known option code.

    First applies a fixed confusion table for characters the OCR engine
    commonly misreads (often caused by rotated or upside-down labels),
    then tries an exact match against VALID_OPTIONS, falling back to a
    fuzzy match with a deliberately loose cutoff.

    Args:
        text: raw OCR string (may be empty or None).

    Returns:
        "" for empty input; otherwise the matched option code, or the
        confusion-corrected string when no whitelist entry is close.
    """
    if not text:
        return ""

    upper = text.upper().strip()

    # Characters the OCR frequently confuses with option letters.
    confusions = {
        "7": "A",  # upside-down A
        "V": "L",  # upside-down L
        "1": "L",
        "|": "L",
        "I": "L",
        "Y": "L",
        "4": "A",
    }
    fixed = "".join(confusions.get(ch, ch) for ch in upper)

    # Exact whitelist hit wins outright.
    if fixed in VALID_OPTIONS:
        return fixed

    # Otherwise take the closest whitelist entry, if any is close enough.
    matches = get_close_matches(fixed, VALID_OPTIONS, n=1, cutoff=0.4)
    return matches[0] if matches else fixed
323
+
324
  def save_crop(label, img):
325
  cv2.imwrite(os.path.join(CROP_DIR, f"{label}.jpg"), img)
326
 
 
354
 
355
  def gpt_single_ocr(label, img):
356
  b64 = img_to_base64(img)
357
+
358
  rules = {
359
  "Manufacture Name": "Read manufacturer name in English only.",
360
  "Circuit Name": "Read the FULL text exactly as printed.",
361
  "Load Name": "Read exact text.",
362
  "AF": "Read the FULL text exactly as printed.",
363
  "AT": "Read the FULL text exactly as printed.",
364
+ "Breaking Capacity": "Read the FULL text exactly as printed.",
365
+
366
+ # ⭐ NEW — OPTION PROMPT
367
+ "Option": (
368
+ "Industrial breaker OPTION label. "
369
+ "Text is short (1–10 characters). "
370
+ "Image may be dark, faint, small, or rotated. "
371
+ "Return ONLY the exact printed text. "
372
+ "If unreadable return empty."
373
+ )
374
  }
375
 
376
  r = client.chat.completions.create(
377
  model="gpt-5.2",
378
  messages=[
379
+ {
380
+ "role": "system",
381
+ "content": (
382
+ "You are an industrial electrical label OCR engine. "
383
+ "Extract text exactly as printed. "
384
+ "Text may be tiny, dark, or rotated. "
385
+ "Return only text."
386
+ )
387
+ },
388
  {"role":"user","content":[
389
  {"type":"text","text":rules[label]},
390
  {"type":"image_url","image_url":{"url":f"data:image/jpeg;base64,{b64}"}}
 
392
  ],
393
  temperature=0
394
  )
395
+
396
  return r.choices[0].message.content.strip()
397
 
398
+
399
  # ================= OCR CORE =================
400
 
401
  def gpt_ocr(label, img):
402
 
403
+ # ================= MANUFACTURER =================
404
  if label in ["Manufacture Name","Load Name"]:
405
  img = enhance(img)
406
  save_crop(label, img)
407
  t = gpt_single_ocr(label, img)
408
  return normalize_manufacturer(t) if label=="Manufacture Name" else remove_spaces_only(t)
409
 
410
+ # ================= BREAKING CAPACITY =================
411
  if label == "Breaking Capacity":
412
  img = enhance_breaking_capacity(img)
413
  t = gpt_single_ocr(label, img)
 
417
 
418
 
419
  # ========= ROTATION BASED LABELS =========
420
+
421
  best_text = ""
422
  best_score = -1
423
  best_img = None
424
 
425
+ # IMPORTANT — choose preprocessing based on label
426
+ if label == "Option":
427
+ base = enhance_option(img)
428
+ else:
429
+ base = enhance(img)
430
 
431
+ for a in [0, 90, 180, 270]:
432
+ rimg = rotate_image(base, a)
433
 
434
  try:
435
  t = gpt_single_ocr(label, rimg)
 
486
  continue
487
 
488
 
489
+ # ================= OPTION ⭐ NEW =================
490
+ elif label == "Option":
491
+
492
+ # ⭐ IMPORTANT — normalize BEFORE scoring
493
+ candidate = normalize_option_text(clean)
494
+
495
+ # Option text usually short
496
+ if len(candidate) <= 10:
497
+ score += 40
498
+
499
+ # prefer valid option hits (strong signal)
500
+ if candidate in VALID_OPTIONS:
501
+ score += 120
502
+
503
+ # prefer alphabetic (real options are alphabetic)
504
+ if re.search(r"[A-Za-z]", candidate):
505
+ score += 20
506
+
507
+ score += len(candidate)
508
+
509
  # Track best candidate
510
  if score > best_score:
511
  best_score = score
 
520
  return ""
521
 
522
 
 
523
  # ================= EXCEL VERIFICATION =================
524
 
525
  def normalize_header(s):
 
563
 
564
  if hdr is None:
565
  return pd.DataFrame([
566
+ ["Excel", "", "エラー", "ヘッダー行が見つかりません。"]
567
  ], columns=["仕様","検出値","Excelに存在?","備考"])
568
 
569
  df = raw.iloc[hdr+1:].copy()
570
  df.columns = raw.iloc[hdr]
571
  df.dropna(how="all", inplace=True)
572
 
 
 
 
 
 
 
 
 
 
 
573
  ccol = find_column(df, ["回路番号","回路"])
574
 
575
  if ccol is None:
 
590
 
591
  rows = []
592
 
593
+ # ⭐ LOOP THROUGH ALL FIELDS (NOW INCLUDES TYPE)
594
  for k, jp in SPEC_JP.items():
595
 
596
  detected_value = det.get(k, "").strip()
597
+
598
+ # ---- column search keys ----
599
+ keys = [jp.replace("(","").replace(")",""), jp[:2]]
600
+
601
+ # ⭐ OPTION SUPPORT
602
+ if k == "Option":
603
+ keys += ["オプション", "オプシ", "ション"]
604
+
605
+ # ⭐ TYPE SUPPORT (NEW 🔥)
606
+ if k == "Type":
607
+ keys += ["タイプ"]
608
+
609
+ if k == "Equipment Type":
610
+ keys += ["機器 種別"]
611
+
612
+ col = find_column(df, keys)
613
 
614
  if col is None:
615
  rows.append([
 
622
 
623
  excel_value = str(target[col])
624
 
625
+ if k == "Manufacture Name":
626
+ detected_jp = MANUFACTURER_JP_MAP.get(detected_value, detected_value)
627
+ excel_jp = excel_value
628
+ ok = normalize_for_compare(detected_jp) == normalize_for_compare(excel_jp)
629
+ rows.append([
630
+ jp,
631
+ detected_jp, # 👈 show Japanese here
632
+ "YES" if ok else "NO",
633
+ "" if ok else f"Excel値: {excel_jp}"
634
+ ])
635
+ continue
636
 
637
+ # not detected in panel
638
  if not detected_value:
639
  rows.append([
640
  jp,
 
656
  return pd.DataFrame(rows,columns=["仕様","検出値","Excelに存在?","備考"])
657
 
658
 
 
 
659
  # ================= PIPELINE & UI =================
660
 
661
def bbox_area(p):
    """Area (width × height) of a Roboflow prediction box dict."""
    return p["height"] * p["width"]
663
 
664
  def run_pipeline(image, excel):
665
 
 
671
  return None, pd.DataFrame(), pd.DataFrame(), None, \
672
  "⚠️ **Please upload the breaker panel image before running verification.**"
673
 
674
+ # ================= DETECTION =================
675
  img = prepare_for_roboflow(image)
676
+
677
  preds = model.predict(
678
  img,
679
  confidence=int(CONF_THRESHOLD*100),
 
681
  ).json()["predictions"]
682
 
683
  vis = img.copy()
684
+ det = {}
685
+ best_boxes = {}
686
 
687
+ # ================= SELECT BEST BOX =================
688
  for p in preds:
689
+ raw_lab = p["class"]
690
+ lab = LABEL_MAP.get(raw_lab.lower(), raw_lab)
691
+
692
  if lab not in best_boxes:
693
+ best_boxes[lab] = p
694
  else:
695
+ if lab == "Circuit Name":
696
+ if bbox_area(p) < bbox_area(best_boxes[lab]):
697
+ best_boxes[lab] = p
698
  else:
699
+ if p["confidence"] > best_boxes[lab]["confidence"]:
700
+ best_boxes[lab] = p
701
+
702
+ # ================= PARALLEL OCR =================
703
+ def process_label(item):
704
+ lab, p = item
705
+
706
+ x, y, w, h = map(int, [p["x"], p["y"], p["width"], p["height"]])
707
+ x1, y1, x2, y2 = x - w//2, y - h//2, x + w//2, y + h//2
708
+
709
+ roi = upscale(crop_with_optional_expand(img, x1, y1, x2, y2, lab))
710
+
711
+ try:
712
+ value = gpt_ocr(lab, roi)
713
+ except Exception as e:
714
+ value = ""
715
+ print(f"OCR Error for {lab}: {e}")
716
 
717
+ return lab, value, (x1, y1, x2, y2)
 
 
 
 
 
718
 
719
+ # 🔥 Parallel execution (IMPORTANT)
720
+ with ThreadPoolExecutor(max_workers=5) as executor:
721
+ results = list(executor.map(process_label, best_boxes.items()))
722
 
723
+ # ================= COLLECT RESULTS =================
724
+ for lab, value, (x1, y1, x2, y2) in results:
725
+ det[lab] = value
726
+ cv2.rectangle(vis, (x1, y1), (x2, y2), (0, 255, 0), 2)
727
 
728
+ # ================= TYPE CLASSIFICATION =================
729
+ pred_type, conf = predict_breaker_type(img)
730
+ det["Type"] = pred_type
731
+
732
+ eq_type, eq_conf, _ = predict_equipment_type(img)
733
+ det["Equipment Type"] = eq_type
734
+
735
+ # ================= DATAFRAMES =================
736
+ ocr_df = pd.DataFrame(det.items(), columns=["Field", "Extracted Text"])
737
+ verify_df = verify_excel(excel, det)
738
+
739
+ # ================= SAVE =================
740
+ out = "verification_result.xlsx"
741
+ with pd.ExcelWriter(out, engine="openpyxl") as w:
742
+ ocr_df.to_excel(w, "OCR_Output", index=False)
743
+ verify_df.to_excel(w, "Verification", index=False)
744
+
745
+ return vis, ocr_df, verify_df, out, ""
746
 
747
  apple_dark_pink_css = """
748
  @import url('https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;500;600&display=swap');
 
793
  }
794
  """
795
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
796
  with gr.Blocks(
797
  theme=gr.themes.Soft(primary_hue="pink"),
798
  css=apple_dark_pink_css
 
819
  [img_out, t1, t2, f, status_msg]
820
  )
821
 
822
+ demo.launch()