APIMONSTER committed
Commit 72b3efa · verified · 1 Parent(s): 5a658a9
Update app.py
Files changed (1):
  app.py +69 -153
app.py CHANGED
@@ -1,171 +1,87 @@
  # app.py
- import cv2
- import json
- import tempfile
+ import cv2, json, tempfile, re
  import numpy as np
- import re
- import paddle
- import paddle.nn as nn
- from ultralytics import YOLO
  import gradio as gr
- from datetime import datetime
- from pathlib import Path
-
- # ─── 0) PlateOCR model definition ────────────────────────────────
- MAX_SEQ_LEN = 15
- LABEL_MAP = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ "
- label_to_int = {c: i for i,c in enumerate(LABEL_MAP)}
- int_to_label = {i: c for c,i in label_to_int.items()}
-
- class OCRHead(nn.Layer):
-     def __init__(self, in_feats):
-         super().__init__()
-         self.fc = nn.Linear(in_feats, MAX_SEQ_LEN * len(LABEL_MAP))
-     def forward(self, x):
-         B = x.shape[0]
-         logits = self.fc(x).reshape([B, MAX_SEQ_LEN, -1])
-         return logits
-
- class PlateOCRTransfer(nn.Layer):
-     def __init__(self):
-         super().__init__()
-         # same backbone as your training script
-         self.backbone = nn.Sequential(
-             nn.Conv2D(3,32,3,padding=1), nn.BatchNorm2D(32), nn.ReLU(), nn.MaxPool2D(2),
-             nn.Conv2D(32,64,3,padding=1), nn.BatchNorm2D(64), nn.ReLU(), nn.MaxPool2D(2),
-             nn.Conv2D(64,128,3,padding=1), nn.BatchNorm2D(128), nn.ReLU(), nn.MaxPool2D(2),
-             nn.Dropout(0.25)
-         )
-         # determine flattened feature size
-         dummy = paddle.randn([1,3,32,128])
-         flat = paddle.flatten(self.backbone(dummy),1).shape[1]
-         self.head = OCRHead(flat)
-
-     def forward(self, x):
-         x = self.backbone(x)
-         x = paddle.flatten(x,1)
-         return self.head(x)
-
- # ─── 1) Greedy decode ─────────────────────────────────────────────
- def greedy_decode(logits):
-     # logits: [1, T, C]
-     pred = logits.argmax(axis=2).numpy()[0]  # [T]
-     res = []
-     prev = -1
-     for idx in pred:
-         if idx != prev and idx < len(LABEL_MAP):
-             res.append(LABEL_MAP[idx])
-         prev = idx
-     return "".join(res).strip()
-
- # ─── 2) Load detection & OCR models ───────────────────────────────
- yolo_model = YOLO("models/best.pt")
-
- ocr_model = PlateOCRTransfer()
- checkpoint = paddle.load("models/best_plate_model.pdparams")
- ocr_model.set_state_dict(checkpoint)
- ocr_model.eval()
-
- # ─── 3) Plate formatting helper ───────────────────────────────────
- def format_turkish_plate(plate: str) -> str:
-     m = re.match(r"^(\d{2})([A-Z]{1,3})(\d{2,4})$", plate.replace(" ", ""))
-     if m:
-         return f"{m.group(1)} {m.group(2)} {m.group(3)}"
-     return "Unknown"
-
- # ─── 4) Single-image pipeline ────────────────────────────────────
- def process_image(img_np, conf_thresh=0.25):
-     img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
-     res = yolo_model(img_bgr, iou=0.3, conf=conf_thresh)[0]
-     boxes = res.boxes.xyxy.cpu().numpy()
-     scores = res.boxes.conf.cpu().numpy()
-
-     annotated = img_bgr.copy()
-     count = 0
-     for (x1,y1,x2,y2), conf in zip(boxes, scores):
-         x1,y1,x2,y2 = map(int,(x1,y1,x2,y2))
-         crop = annotated[y1:y2, x1:x2]
-         if crop.size == 0:
-             continue
-
-         # preprocess for PlateOCR
-         plate = cv2.resize(crop, (128,32)).astype("float32") / 255.0
-         inp = paddle.to_tensor(plate.transpose(2,0,1)[None,:,:,:])
-         with paddle.no_grad():
-             logits = ocr_model(inp)  # [1,T,C]
-         txt = greedy_decode(logits)
-         fmtd = format_turkish_plate(txt)
-
-         label = f"{fmtd} ({conf:.2f})"
-         cv2.rectangle(annotated,(x1,y1),(x2,y2),(0,255,0),2)
-         cv2.putText(annotated, label, (x1, y1-6),
-                     cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,255,0), 2)
-         count += 1
-
-     out = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
-     return out, f"{count} plate(s) detected"
-
- # ─── 5) Video pipeline ───────────────────────────────────────────
- def process_video(video_file, conf_thresh=0.25):
+ from ultralytics import YOLO
+ from paddleocr import PaddleOCR
+
+ # 1) load detection + OCR
+ yolo = YOLO("models/best.pt")
+ ocr = PaddleOCR(
+     det_model_dir=None,                # turn off internal detector
+     rec_model_dir="models/ocr_model",  # inference export dir
+     use_textline_orientation=True      # orientation head
+ )
+
+ # 2) helper to enforce "DD AAA NNNN" style
+ def fmt_plate(s):
+     m = re.match(r"^(\d{2})([A-Z]{1,3})(\d{2,4})$", s.replace(" ",""))
+     return f"{m[1]} {m[2]} {m[3]}" if m else "Unknown"
+
+ # 3) image pipeline
+ def run_image(img, conf=0.25):
+     bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+     res = yolo(bgr, conf=conf)[0]
+     out = bgr.copy()
+     for box,score in zip(res.boxes.xyxy.cpu().numpy(), res.boxes.conf.cpu().numpy()):
+         x1,y1,x2,y2 = map(int,box)
+         crop = out[y1:y2, x1:x2]
+         if crop.size==0: continue
+         plate_img = cv2.resize(crop,(128,32))
+         rec = ocr.ocr(plate_img, cls=True)[0]
+         txt = "".join(seg[1][0] for seg in rec)
+         label = fmt_plate(txt)
+         cv2.rectangle(out,(x1,y1),(x2,y2),(0,255,0),2)
+         cv2.putText(out, f"{label} {score:.2f}", (x1,y1-5),
+                     cv2.FONT_HERSHEY_SIMPLEX,0.6,(0,255,0),2)
+     return cv2.cvtColor(out,cv2.COLOR_BGR2RGB), f"{len(res.boxes)} plates detected"
+
+ # 4) video pipeline (frame-by-frame, writes output.json)
+ def run_video(video_file, conf=0.25):
      cap = cv2.VideoCapture(video_file)
      fps = cap.get(cv2.CAP_PROP_FPS)
-     w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-     h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-
-     tmp_out = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
-     writer = cv2.VideoWriter(tmp_out, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w,h))
-
-     logs = []
-     frame_i = 0
+     w,h = int(cap.get(3)), int(cap.get(4))
+     out_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
+     writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w,h))
+     records = []
+     idx = 0
      while True:
-         ret, frame = cap.read()
+         ret,frame = cap.read()
          if not ret: break
-         frame_i += 1
-         t = frame_i / fps
-
-         res = yolo_model(frame, iou=0.3, conf=conf_thresh)[0]
-         boxes = res.boxes.xyxy.cpu().numpy()
-
-         for (x1,y1,x2,y2) in boxes:
-             x1,y1,x2,y2 = map(int,(x1,y1,x2,y2))
+         idx+=1; t=idx/fps
+         res = yolo(frame, conf=conf)[0]
+         for (x1,y1,x2,y2) in res.boxes.xyxy.cpu().numpy().astype(int):
              crop = frame[y1:y2, x1:x2]
              if crop.size==0: continue
-
-             plate = cv2.resize(crop,(128,32)).astype("float32")/255.0
-             inp = paddle.to_tensor(plate.transpose(2,0,1)[None,:,:,:])
-             with paddle.no_grad():
-                 logits = ocr_model(inp)
-             txt = greedy_decode(logits)
-             fmtd = format_turkish_plate(txt)
-             if fmtd!="Unknown":
-                 logs.append({"time_s":round(t,2),"plate":fmtd})
-             cv2.rectangle(frame,(x1,y1),(x2,y2),(0,255,0),2)
-             cv2.putText(frame, fmtd, (x1,y1-6),
-                         cv2.FONT_HERSHEY_SIMPLEX,0.6,(0,255,0),2)
-
+             plate = cv2.resize(crop,(128,32))
+             rec = ocr.ocr(plate, cls=True)[0]
+             txt = "".join(seg[1][0] for seg in rec)
+             label = fmt_plate(txt)
+             score = min(seg[1][1] for seg in rec) if rec else 0.0
+             if label!="Unknown":
+                 records.append({"time_s":round(t,2),"plate":label,"conf":round(score,3)})
+             cv2.rectangle(frame,(x1,y1),(x2,y2),(0,255,0),2)
+             cv2.putText(frame,label,(x1,y1-5),cv2.FONT_HERSHEY_SIMPLEX,0.6,(0,255,0),2)
          writer.write(frame)
-
      cap.release(); writer.release()
-     with open("output.json","w") as f:
-         json.dump(logs,f,indent=2)
-     return tmp_out
+     with open("output.json","w") as f: json.dump(records,f,indent=2)
+     return out_path
 
- # ─── 6) Gradio UI ───────────────────────────────────────────────
+ # 5) Gradio UI
  with gr.Blocks() as demo:
-     gr.Markdown("## 🚗 License Plate Detection & OCR")
+     gr.Markdown("## 🚗 Plate Detection + Recognition")
      with gr.Row():
          with gr.Column():
-             inp_img = gr.Image(type="numpy", label="Upload Image")
-             inp_vid = gr.File(label="Upload Video (.mp4)")
-             conf = gr.Slider(0,1,0.25,0.01, label="YOLO Confidence")
-             b1 = gr.Button("Run Image")
-             b2 = gr.Button("Run Video")
+             img_in = gr.Image(type="numpy", label="Image")
+             vid_in = gr.File(label="Video (.mp4)")
+             conf = gr.Slider(0,1,0.25,0.01, label="YOLO confidence")
+             b1 = gr.Button("Process Image")
+             b2 = gr.Button("Process Video")
          with gr.Column():
-             out_img = gr.Image(type="numpy", label="Annotated Image")
-             out_vid = gr.Video(label="Annotated Video")
-             out_txt = gr.Textbox(label="Status / JSON Path")
-     b1.click(process_image, [inp_img, conf], [out_img, out_txt])
-     b2.click(process_video, [inp_vid, conf], [out_vid, out_txt])
-
+             img_out = gr.Image(type="numpy", label="Result")
+             vid_out = gr.Video(label="Annotated Video")
+             txt_out = gr.Textbox(label="Status / JSON path")
+     b1.click(run_image, [img_in,conf],[img_out,txt_out])
+     b2.click(run_video, [vid_in,conf],[vid_out,txt_out])
  if __name__=="__main__":
      demo.launch()
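
One caveat in the new recognition calls: `ocr.ocr(plate_img, cls=True)[0]` assumes something was read. Depending on the PaddleOCR version, the per-image result can be None or an empty list when the recognizer finds no text, in which case `"".join(seg[1][0] for seg in rec)` raises. Also, in the PaddleOCR releases I am aware of, `det_model_dir=None` means "use the default detection model" rather than "turn off the internal detector"; skipping detection on an already-cropped plate is usually done per call with `det=False`. A minimal defensive sketch (the helper name `read_plate` is hypothetical, and it assumes the classic `[box, (text, score)]` segment layout):

    import cv2

    def read_plate(ocr, crop):
        # Run recognition on a plate crop; assumes PaddleOCR's classic
        # [box, (text, score)] result layout (version-dependent assumption).
        result = ocr.ocr(cv2.resize(crop, (128, 32)), cls=True)
        if not result or result[0] is None:
            return "", 0.0                          # nothing readable in this crop
        segs = result[0]
        text = "".join(seg[1][0] for seg in segs)
        score = min(seg[1][1] for seg in segs)      # most pessimistic segment
        return text, score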
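Separately, both the old `process_video` and the new `run_video` return a single value while their buttons are wired to two outputs (`[out_vid, out_txt]` before, `[vid_out, txt_out]` after), so depending on the Gradio version the video run either errors out or leaves the status textbox empty. A small sketch of a fix, keeping the existing wiring and returning a status string alongside the path (the wording of the status message is my own):

    # revised tail of run_video: return both values b2.click expects
    with open("output.json", "w") as f:
        json.dump(records, f, indent=2)
    return out_path, f"{len(records)} detection(s) logged to output.json"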
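Finally, two small robustness points in `run_video`'s setup: `cap.get(3)` and `cap.get(4)` are the numeric aliases of `cv2.CAP_PROP_FRAME_WIDTH` and `cv2.CAP_PROP_FRAME_HEIGHT`, which read better spelled out, and `cap.get(cv2.CAP_PROP_FPS)` can return 0 for some containers, which would make `t = idx/fps` divide by zero. A guarded version of the setup (the 25.0 fallback is an arbitrary nominal rate):

    cap = cv2.VideoCapture(video_file)
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0       # some files report 0 fps; fall back
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))    # named constant for cap.get(3)
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))   # named constant for cap.get(4)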