Steven1310 committed on
Commit
24f7ea5
·
1 Parent(s): 962c956

Initial Captcha OCR Space

Browse files
Files changed (1) hide show
  1. app.py +50 -106
app.py CHANGED
@@ -4,129 +4,73 @@ import onnxruntime as rt
4
  from torchvision import transforms as T
5
  from pathlib import Path
6
  from PIL import Image
7
- from fastapi import FastAPI, UploadFile, File, Body
8
- from fastapi.responses import JSONResponse
9
- from pydantic import BaseModel
10
  from utils.tokenizer_base import Tokenizer
 
11
  import io
12
- import os
13
  import base64
14
- import gradio as gr
15
 
16
- # =========================
17
  # MODEL SETUP
18
- # =========================
19
- model_path = "models/model.onnx"
20
- cwd = Path(__file__).parent.resolve()
21
- model_file = os.path.join(cwd, model_path)
22
-
23
- if not os.path.exists(model_file):
24
- raise FileNotFoundError(f"Model not found at {model_file}")
25
 
26
  img_size = (32, 128)
27
  vocab = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
28
  tokenizer = Tokenizer(vocab)
29
 
 
 
 
 
 
30
 
31
- def to_numpy(tensor):
32
- return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
33
-
34
-
35
- def get_transform(img_size):
36
- return T.Compose([
37
- T.Resize(img_size, T.InterpolationMode.BICUBIC),
38
- T.ToTensor(),
39
- T.Normalize(0.5, 0.5),
40
- ])
41
-
42
 
43
- def load_model(model_file):
44
- transform = get_transform(img_size)
45
- onnx_model = onnx.load(model_file)
46
- onnx.checker.check_model(onnx_model)
47
- session = rt.InferenceSession(model_file)
48
- return transform, session
49
 
 
 
50
 
51
- transform, session = load_model(model_file)
52
 
53
- # =========================
54
- # SHARED INFERENCE LOGIC
55
- # =========================
56
- def predict_from_image(img: Image.Image) -> str:
57
  x = transform(img.convert("RGB")).unsqueeze(0)
58
- ort_inputs = {session.get_inputs()[0].name: to_numpy(x)}
59
- logits = session.run(None, ort_inputs)[0]
60
  probs = torch.tensor(logits).softmax(-1)
61
  preds, _ = tokenizer.decode(probs)
62
  return preds[0]
63
 
64
- # =========================
65
- # FASTAPI SETUP
66
- # =========================
67
- app = FastAPI(title="OCR CAPTCHA API")
68
-
69
- class Base64ImageRequest(BaseModel):
70
- image_base64: str
71
-
72
-
73
- @app.post("/predict/file")
74
- async def predict_file(file: UploadFile = File(...)):
75
- """
76
- Accepts raw bytes (multipart/form-data)
77
- """
78
- try:
79
- contents = await file.read()
80
- img = Image.open(io.BytesIO(contents))
81
- result = predict_from_image(img)
82
- return {"predicted_text": result}
83
- except Exception as e:
84
- return JSONResponse({"error": str(e)}, status_code=500)
85
-
86
-
87
- @app.post("/predict/base64")
88
- async def predict_base64(payload: Base64ImageRequest):
89
- """
90
- Accepts base64-encoded image
91
- """
92
- try:
93
- image_bytes = base64.b64decode(payload.image_base64)
94
- img = Image.open(io.BytesIO(image_bytes))
95
- result = predict_from_image(img)
96
- return {"predicted_text": result}
97
- except Exception as e:
98
- return JSONResponse({"error": str(e)}, status_code=500)
99
-
100
- # =========================
101
- # GRADIO UI
102
- # =========================
103
- def gradio_predict(img: Image.Image):
104
- if img is None:
105
- return ""
106
- return predict_from_image(img)
107
-
108
-
109
- gradio_ui = gr.Interface(
110
- fn=gradio_predict,
111
- inputs=gr.Image(type="pil", label="Input Image"),
112
- outputs=gr.Textbox(label="Predicted Text"),
113
- title="OCR CAPTCHA Solver",
114
- description="OCR model for captcha images (letters + numbers).",
115
- examples=[
116
- "examples/1.png",
117
- "examples/2.jpg",
118
- ],
119
- )
120
-
121
- # =========================
122
- # MOUNT GRADIO INTO FASTAPI
123
- # =========================
124
- app = gr.mount_gradio_app(app, gradio_ui, path="/")
125
-
126
-
127
- # =========================
128
- # LOCAL RUN
129
- # =========================
130
- # if __name__ == "__main__":
131
- # import uvicorn
132
- # uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)
 
4
  from torchvision import transforms as T
5
  from pathlib import Path
6
  from PIL import Image
 
 
 
7
  from utils.tokenizer_base import Tokenizer
8
+ import gradio as gr
9
  import io
 
10
  import base64
11
+ import os
12
 
13
# =====================
# MODEL SETUP
# =====================
model_file = Path(__file__).parent / "models/model.onnx"
if not model_file.exists():
    # FileNotFoundError is the precise exception for a missing file
    # (and matches what earlier revisions of this app raised).
    raise FileNotFoundError(f"Model not found at {model_file}")

# Fixed (height, width) the ONNX model expects for its input images.
img_size = (32, 128)
# Character inventory the tokenizer decodes logits into:
# digits, lowercase/uppercase letters, then ASCII punctuation.
vocab = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
tokenizer = Tokenizer(vocab)

# Preprocessing pipeline: resize to the model's input size (bicubic),
# convert to a [0, 1] tensor, then normalize to [-1, 1].
transform = T.Compose([
    T.Resize(img_size, T.InterpolationMode.BICUBIC),
    T.ToTensor(),
    T.Normalize(0.5, 0.5),
])

# Explicit providers list: onnxruntime >= 1.9 expects it to be spelled out;
# CPU is the safe choice for a Space without a GPU.
session = rt.InferenceSession(str(model_file), providers=["CPUExecutionProvider"])
 
 
 
 
 
 
 
 
 
 
31
 
 
 
 
 
 
 
32
 
33
def to_numpy(t):
    """Return *t* as a NumPy array, detached from autograd and moved to CPU."""
    detached = t.detach()
    return detached.cpu().numpy()
35
 
 
36
 
37
def infer(img: Image.Image):
    """Run the ONNX OCR model on a PIL image and return the decoded text."""
    # Force RGB, apply the resize/normalize pipeline, prepend a batch axis.
    batch = transform(img.convert("RGB")).unsqueeze(0)
    # The model has a single input; run and keep the first (only) output.
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: to_numpy(batch)})
    logits = outputs[0]
    # Softmax over the vocabulary axis, then let the tokenizer decode.
    probs = torch.tensor(logits).softmax(-1)
    # tokenizer.decode returns (texts, <extra>) — presumably per-char
    # confidences; only the first decoded string of the batch is needed.
    texts, _ = tokenizer.decode(probs)
    return texts[0]
43
 
44
+
45
# =====================
# GRADIO FUNCTIONS
# =====================
def predict_image(img):
    """Predict captcha text from an uploaded PIL image.

    Returns "" when no image was provided — Gradio passes ``None`` for an
    empty image component, and feeding that into the model would crash.
    """
    if img is None:
        return ""
    return infer(img)
50
+
51
+
52
def predict_base64(b64: str):
    """Predict captcha text from a base64-encoded image.

    Accepts either a raw base64 payload or a full data URL
    ("data:image/png;base64,...."); the data-URL prefix, if present,
    is stripped before decoding, since browser clients commonly send it.
    Decoding errors (binascii.Error, PIL.UnidentifiedImageError) propagate
    to the caller unchanged.
    """
    # Keep only the payload after a "base64," marker, if one exists.
    _, sep, tail = b64.partition("base64,")
    payload = tail if sep else b64
    img_bytes = base64.b64decode(payload)
    img = Image.open(io.BytesIO(img_bytes))
    return infer(img)
56
+
57
+
58
# =====================
# GRADIO APP (REQUIRED)
# =====================
# Two tabs share the same model: a visual upload tab for humans and a
# base64 text tab usable as a crude API via the Gradio client.
with gr.Blocks(title="Captcha OCR") as demo:
    gr.Markdown("# Captcha OCR")
    gr.Markdown("OCR for captcha images (letters & numbers)")

    with gr.Tab("Image Upload"):
        # type="pil" hands predict_image a PIL.Image (or None when empty).
        img = gr.Image(type="pil")
        out = gr.Textbox()
        gr.Button("Predict").click(predict_image, img, out)

    with gr.Tab("Base64 API"):
        b64 = gr.Textbox(label="Base64 Image")
        out2 = gr.Textbox()
        gr.Button("Predict").click(predict_base64, b64, out2)

# queue() serializes concurrent requests; launch() starts the server at
# import time, as Hugging Face Spaces expects for a Gradio app.
demo.queue()
demo.launch()