Chung-Fan commited on
Commit
9415ed0
·
1 Parent(s): 1f7bb44

update app

Browse files
Files changed (1) hide show
  1. app.py +46 -15
app.py CHANGED
@@ -1,45 +1,76 @@
1
  import torch
2
  from src.model import CRNN
3
- from PIL import Image as PILImage
4
  import torchvision.transforms as transforms
5
  import gradio as gr
 
 
 
 
 
 
 
 
6
 
7
- # Load CRNN model
8
  model = CRNN(img_height=32, img_width=100, img_channel=1, num_class=37, rnn_hidden=256)
9
- model.load_state_dict(torch.load("crnn_gpu.pt", map_location="cpu"))
10
  model.eval()
11
 
 
 
 
12
  alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'
13
 
14
  def ctc_decode(preds):
 
15
  preds = preds.argmax(2).transpose(1,0).contiguous().view(-1)
16
  decoded = []
17
  prev_idx = -1
18
  for idx in preds:
19
- if idx != prev_idx and idx != 0:
20
  decoded.append(alphabet[idx-1])
21
  prev_idx = idx
22
  return ''.join(decoded)
23
 
 
 
 
 
 
 
 
 
 
24
  transform = transforms.Compose([
25
- transforms.Grayscale(),
26
- transforms.Resize((32,100)),
27
  transforms.ToTensor(),
28
  transforms.Normalize((0.5,), (0.5,))
29
  ])
30
 
31
- def ocr(image: PILImage.Image):
32
- img_tensor = transform(image).unsqueeze(0)
33
- with torch.no_grad():
34
- preds = model(img_tensor)
35
- text = ctc_decode(preds)
36
- return text
 
 
 
 
 
 
37
 
 
 
 
38
  iface = gr.Interface(
39
  fn=ocr,
40
- inputs=gr.Image(type="pil", label="Upload and crop image"),
41
  outputs="text",
42
- title="CRNN OCR"
 
43
  )
44
 
45
- iface.launch()
 
 
1
  import torch
2
  from src.model import CRNN
3
+ from PIL import Image
4
  import torchvision.transforms as transforms
5
  import gradio as gr
6
+ import os
7
+
8
+ # ----------------------------
9
+ # 1️⃣ Load CRNN model
10
+ # ----------------------------
11
+ MODEL_PATH = "crnn_gpu.pt"
12
+ if not os.path.exists(MODEL_PATH):
13
+ raise FileNotFoundError(f"{MODEL_PATH} not found! Make sure it's in the Space root.")
14
 
 
15
  model = CRNN(img_height=32, img_width=100, img_channel=1, num_class=37, rnn_hidden=256)
16
+ model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
17
  model.eval()
18
 
19
+ # ----------------------------
20
+ # 2️⃣ Characters and CTC decoding
21
+ # ----------------------------
22
  alphabet = '0123456789abcdefghijklmnopqrstuvwxyz'
23
 
24
  def ctc_decode(preds):
25
+ """Greedy CTC decoder"""
26
  preds = preds.argmax(2).transpose(1,0).contiguous().view(-1)
27
  decoded = []
28
  prev_idx = -1
29
  for idx in preds:
30
+ if idx != prev_idx and idx != 0: # skip duplicates & blank
31
  decoded.append(alphabet[idx-1])
32
  prev_idx = idx
33
  return ''.join(decoded)
34
 
35
+ # ----------------------------
36
+ # 3️⃣ Preprocessing
37
+ # ----------------------------
38
+ def to_grayscale(img: Image.Image):
39
+ """Convert any image type to grayscale"""
40
+ if img.mode != "L":
41
+ return img.convert("L")
42
+ return img
43
+
44
  transform = transforms.Compose([
45
+ transforms.Lambda(to_grayscale), # convert any input image to grayscale
46
+ transforms.Resize((32, 100)), # match CRNN input
47
  transforms.ToTensor(),
48
  transforms.Normalize((0.5,), (0.5,))
49
  ])
50
 
51
+ # ----------------------------
52
+ # 4️⃣ OCR function
53
+ # ----------------------------
54
+ def ocr(image: Image.Image):
55
+ try:
56
+ img_tensor = transform(image).unsqueeze(0) # add batch dimension
57
+ with torch.no_grad():
58
+ preds = model(img_tensor)
59
+ text = ctc_decode(preds)
60
+ return text
61
+ except Exception as e:
62
+ return f"Error during inference: {e}"
63
 
64
+ # ----------------------------
65
+ # 5️⃣ Gradio interface
66
+ # ----------------------------
67
  iface = gr.Interface(
68
  fn=ocr,
69
+ inputs=gr.Image(type="pil", label="Upload any image (RGB, RGBA, etc.)"),
70
  outputs="text",
71
+ title="CRNN OCR",
72
+ description="Upload an image and get the OCR text prediction."
73
  )
74
 
75
+ # Launch
76
+ iface.launch(share=True)