Acetde committed on
Commit
a0db112
·
verified ·
1 Parent(s): 9606b04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -35
app.py CHANGED
@@ -4,67 +4,72 @@ import onnxruntime as rt
4
  from torchvision import transforms as T
5
  from PIL import Image
6
  from tokenizer_base import Tokenizer
7
- import pathlib
8
- import os
9
- import gradio as gr
10
- from huggingface_hub import Repository
11
-
12
-
13
 
 
14
  model_file = "captcha.onnx"
15
  img_size = (32,128)
16
  charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
17
  tokenizer_base = Tokenizer(charset)
18
 
 
19
  def get_transform(img_size):
20
- transforms = []
21
- transforms.extend([
22
- T.Resize(img_size, T.InterpolationMode.BICUBIC),
23
- T.ToTensor(),
24
- T.Normalize(0.5, 0.5)
25
- ])
26
- return T.Compose(transforms)
27
 
28
  def to_numpy(tensor):
29
  return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
30
 
 
31
  def initialize_model(model_file):
32
  transform = get_transform(img_size)
33
- # Onnx model loading
34
  onnx_model = onnx.load(model_file)
35
  onnx.checker.check_model(onnx_model)
36
  ort_session = rt.InferenceSession(model_file)
37
- return transform,ort_session
 
 
 
38
 
 
39
  def get_text(img_org):
40
- # img_org = Image.open(image_path)
41
- # Preprocess. Model expects a batch of images with shape: (B, C, H, W)
42
  x = transform(img_org.convert('RGB')).unsqueeze(0)
43
-
44
- # compute ONNX Runtime output prediction
45
  ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
46
  logits = ort_session.run(None, ort_inputs)[0]
47
  probs = torch.tensor(logits).softmax(-1)
48
  preds, probs = tokenizer_base.decode(probs)
49
  preds = preds[0]
50
- print(preds)
51
  return preds
52
 
53
- transform,ort_session = initialize_model(model_file=model_file)
 
 
 
 
 
54
 
55
- gr.Interface(
56
- get_text,
57
- inputs=gr.Image(type="pil"),
58
- outputs=gr.Textbox(),
59
- title="Text Captcha Reader",
60
- examples=["8000.png","11JW29.png","2a8486.jpg","2nbcx.png",
61
- "000679.png","000HU.png","00Uga.png.jpg","00bAQwhAZU.jpg",
62
- "00h57kYf.jpg","0EoHdtVb.png","0JS21.png","0p98z.png","10010.png"]
63
- ).launch()
 
 
64
 
65
- # if __name__ == "__main__":
66
- # image_path = "8000.png"
67
- # preds,probs = get_text(image_path)
68
- # print(preds[0])
69
-
70
 
 
4
  from torchvision import transforms as T
5
  from PIL import Image
6
  from tokenizer_base import Tokenizer
7
+ import io
8
+ import base64
9
+ from fastapi import FastAPI, UploadFile, File
10
+ from pydantic import BaseModel
11
+ import numpy as np
 
12
 
13
# Model parameters.
# Path of the ONNX model file passed to initialize_model().
model_file = "captcha.onnx"
# Target size handed to T.Resize inside get_transform().
img_size = (32,128)
# Characters the tokenizer is built from.
# NOTE(review): this is a raw string, so the sequences \" and \\ are kept
# verbatim (backslashes included) instead of being processed as escapes;
# confirm the resulting character list matches what the model expects.
charset = r"0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
# Tokenizer used by get_text() to decode the model output.
tokenizer_base = Tokenizer(charset)
18
 
19
# Image preprocessing pipeline
def get_transform(img_size):
    """Build the preprocessing pipeline applied to input images.

    Args:
        img_size: Target size forwarded to ``T.Resize``.

    Returns:
        A ``T.Compose`` that resizes with bicubic interpolation, converts
        the image to a tensor and normalizes it with ``T.Normalize(0.5, 0.5)``.
    """
    resize = T.Resize(img_size, T.InterpolationMode.BICUBIC)
    to_tensor = T.ToTensor()
    normalize = T.Normalize(0.5, 0.5)
    return T.Compose([resize, to_tensor, normalize])
28
 
29
def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array on the CPU.

    Tensors that require grad are detached before the conversion.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
31
 
32
# Model loading
def initialize_model(model_file):
    """Validate the ONNX model and open an inference session for it.

    Args:
        model_file: Path of the ONNX model file.

    Returns:
        Tuple ``(transform, ort_session)``: the preprocessing pipeline built
        from the module-level ``img_size`` and the ONNX Runtime session.
    """
    preprocessing = get_transform(img_size)
    # Run the ONNX checker on the loaded model before creating the session.
    onnx.checker.check_model(onnx.load(model_file))
    session = rt.InferenceSession(model_file)
    return preprocessing, session
39
+
40
# Build the preprocessing pipeline and the ONNX Runtime session at import
# time; get_text() reads both as module-level globals.
transform, ort_session = initialize_model(model_file=model_file)
42
 
43
# Read the text shown in an image
def get_text(img_org):
    """Return the text the model predicts for an image.

    Args:
        img_org: Image object; it is converted to RGB before preprocessing.

    Returns:
        The first decoded prediction produced by ``tokenizer_base.decode``.
    """
    # Preprocess and add a leading dimension of size one.
    batch = transform(img_org.convert('RGB')).unsqueeze(0)
    input_name = ort_session.get_inputs()[0].name
    outputs = ort_session.run(None, {input_name: to_numpy(batch)})
    # Softmax over the last axis of the first model output.
    probabilities = torch.tensor(outputs[0]).softmax(-1)
    decoded, _ = tokenizer_base.decode(probabilities)
    return decoded[0]
52
 
53
# FastAPI application instance; the predict route is registered on it
# through the @app.post decorator.
app = FastAPI()
55
+
56
# Request body model for the /predict/ endpoint
class ImageData(BaseModel):
    """Body accepted by ``predict``."""

    # Base64 image data; predict() decodes the part after the comma of a
    # data-URL style string ("<prefix>,<base64>").
    image: str
59
 
60
# Endpoint that runs the model on a Base64-encoded image
@app.post("/predict/")
async def predict(data: ImageData):
    """Decode the submitted image and return the recognized text.

    Args:
        data: Request body whose ``image`` field holds Base64 image data,
            either as a data URL (``"data:image/png;base64,...."``) or as a
            bare Base64 string.

    Returns:
        ``{"result": <text>}`` on success, or ``{"error": <message>}`` when
        decoding or inference raises.
    """
    # NOTE(review): get_text() is a synchronous call inside this coroutine;
    # consider whether the route should be a plain ``def`` so the work does
    # not run on the event loop.
    try:
        # A data URL carries its payload after the comma. A bare Base64
        # string has no comma, so fall back to the whole value instead of
        # failing with an IndexError.
        pieces = data.image.split(",")
        encoded = pieces[1] if len(pieces) > 1 else pieces[0]
        img_data = base64.b64decode(encoded)
        img = Image.open(io.BytesIO(img_data))
        result = get_text(img)
        return {"result": result}
    except Exception as e:
        # Endpoint boundary: report the failure to the client in the
        # response body rather than letting it propagate.
        return {"error": str(e)}
71
 
72
# Starting the server (for local testing)
# To run the server on a host, use the command:
# uvicorn app:app --reload
 
 
75