AnasHXH commited on
Commit
6399a41
·
verified ·
1 Parent(s): 510c371

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +173 -0
app.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import cv2
4
+ import time
5
+ import json
6
+ import math
7
+ import torch
8
+ import numpy as np
9
+ import pandas as pd
10
+ from PIL import Image
11
+ import gradio as gr
12
+
13
+ # ----------------------------
14
+ # Config
15
+ # ----------------------------
16
+ DEFAULT_MODEL_PATH = os.getenv("MODEL_PATH", "weights/best.pt")
17
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
+
19
+ # ----------------------------
20
+ # Model loading (edit this)
21
+ # ----------------------------
22
+ _model = None
23
+
24
+ def load_model(model_path: str = DEFAULT_MODEL_PATH):
25
+ """
26
+ Load your trained model once when the Space boots.
27
+ Replace the placeholder with your code.
28
+ """
29
+ global _model
30
+ if _model is not None:
31
+ return _model
32
+
33
+ # >>> YOUR MODEL HERE <<<
34
+ # Example (PyTorch scripted/ckpt):
35
+ # ckpt = torch.load(model_path, map_location=DEVICE)
36
+ # model = MyNet(...)
37
+ # model.load_state_dict(ckpt["state_dict"] if "state_dict" in ckpt else ckpt)
38
+ # model.to(DEVICE).eval()
39
+ #
40
+ # For YOLO-like:
41
+ # from ultralytics import YOLO
42
+ # model = YOLO(model_path)
43
+
44
+ # Placeholder “no-model” to keep UI running:
45
+ class DummyModel:
46
+ def __init__(self):
47
+ pass
48
+ _model = DummyModel()
49
+ return _model
50
+
51
+ # ----------------------------
52
+ # Inference wrapper (edit this)
53
+ # ----------------------------
54
+ def infer(image_bgr: np.ndarray, conf: float = 0.25):
55
+ """
56
+ Return defects as a list of boxes: [x1,y1,x2,y2,score,label]
57
+ OR return a binary mask (H,W) where 1=defect.
58
+ Edit this to call your model.
59
+ """
60
+ model = load_model()
61
+
62
+ # >>> YOUR MODEL HERE <<<
63
+ # Option A (detection):
64
+ # results = model(image_bgr[..., ::-1]) # example if model expects RGB
65
+ # boxes = [[x1,y1,x2,y2,score,"defect"], ...]
66
+ # return {"type": "boxes", "boxes": boxes}
67
+
68
+ # Option B (segmentation):
69
+ # mask = your_segmentation(image_bgr) # 0/1 uint8
70
+ # return {"type": "mask", "mask": mask}
71
+
72
+ # --------- PLACEHOLDER (edge blobs as fake defects) ---------
73
+ gray = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
74
+ e = cv2.Canny(gray, 50, 150)
75
+ cnts, _ = cv2.findContours(e, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
76
+ boxes = []
77
+ h, w = gray.shape[:2]
78
+ for c in cnts:
79
+ x, y, bw, bh = cv2.boundingRect(c)
80
+ if bw * bh < max(0.0005 * w * h, 150): # skip tiny
81
+ continue
82
+ boxes.append([x, y, x + bw, y + bh, 0.5, "defect"])
83
+ if len(boxes) >= 20:
84
+ break
85
+ return {"type": "boxes", "boxes": boxes}
86
+
87
+ # ----------------------------
88
+ # Utilities
89
+ # ----------------------------
90
+ def draw_boxes_with_x(image_bgr: np.ndarray, boxes, thickness: int = 3):
91
+ img = image_bgr.copy()
92
+ color = (0, 0, 255) # red in BGR
93
+ for (x1, y1, x2, y2, score, label) in boxes:
94
+ x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
95
+ cv2.rectangle(img, (x1, y1), (x2, y2), color, thickness)
96
+ # draw X
97
+ cv2.line(img, (x1, y1), (x2, y2), color, thickness)
98
+ cv2.line(img, (x1, y2), (x2, y1), color, thickness)
99
+ # label
100
+ txt = f"{label}:{score:.2f}"
101
+ cv2.putText(img, txt, (x1, max(y1 - 6, 0)),
102
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2, cv2.LINE_AA)
103
+ return img
104
+
105
+ def boxes_from_mask(mask: np.ndarray, min_area: int = 50):
106
+ mask = (mask > 0).astype(np.uint8)
107
+ cnts, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
108
+ out = []
109
+ for c in cnts:
110
+ x, y, w, h = cv2.boundingRect(c)
111
+ if w * h >= min_area:
112
+ out.append([x, y, x + w, y + h, 1.0, "defect"])
113
+ return out
114
+
115
+ def to_csv_file(rows, path="/tmp/defect_report.csv"):
116
+ df = pd.DataFrame(rows, columns=["x1", "y1", "x2", "y2", "score", "label"])
117
+ df.to_csv(path, index=False)
118
+ return path, df
119
+
120
+ # ----------------------------
121
+ # Gradio handlers
122
+ # ----------------------------
123
+ def process(image: Image.Image, conf: float, draw_x: bool, min_area: int):
124
+ if image is None:
125
+ return None, pd.DataFrame(), None
126
+
127
+ # PIL -> BGR np
128
+ img_rgb = np.array(image.convert("RGB"))
129
+ img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
130
+
131
+ res = infer(img_bgr, conf=conf)
132
+
133
+ if res["type"] == "mask":
134
+ boxes = boxes_from_mask(res["mask"], min_area=min_area)
135
+ else:
136
+ boxes = [b for b in res["boxes"] if b[4] >= conf]
137
+
138
+ # draw
139
+ vis = draw_boxes_with_x(img_bgr, boxes) if draw_x else img_bgr.copy()
140
+ vis_rgb = cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)
141
+
142
+ # csv + table
143
+ csv_path, df = to_csv_file(boxes)
144
+
145
+ return Image.fromarray(vis_rgb), df, csv_path
146
+
147
+ # ----------------------------
148
+ # UI
149
+ # ----------------------------
150
+ with gr.Blocks(title="AI-Driven EL Defect Recognition") as demo:
151
+ gr.Markdown(
152
+ "## AI-Driven Defect Recognition in EL Images\n"
153
+ "Upload an electroluminescence (EL) image. The app detects defective cells, "
154
+ "draws a red square with an X, and provides a CSV report."
155
+ )
156
+ with gr.Row():
157
+ with gr.Column():
158
+ inp = gr.Image(type="pil", label="Upload EL image")
159
+ conf = gr.Slider(0.0, 1.0, value=0.25, step=0.01, label="Confidence threshold")
160
+ draw_x = gr.Checkbox(True, label="Draw red box + X")
161
+ min_area = gr.Slider(10, 5000, value=120, step=10, label="Min defect area (pixels, for masks)")
162
+ run_btn = gr.Button("Run inference", variant="primary")
163
+ with gr.Column():
164
+ out_img = gr.Image(type="pil", label="Annotated output")
165
+ out_table = gr.Dataframe(headers=["x1","y1","x2","y2","score","label"], label="Defect report (preview)")
166
+ out_csv = gr.File(label="Download CSV")
167
+
168
+ run_btn.click(process, inputs=[inp, conf, draw_x, min_area],
169
+ outputs=[out_img, out_table, out_csv])
170
+
171
+ if __name__ == "__main__":
172
+ load_model() # warmup
173
+ demo.launch()