longjava2024 commited on
Commit
d8afded
·
verified ·
1 Parent(s): c67c2e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +244 -26
app.py CHANGED
@@ -1,41 +1,259 @@
1
- from fastapi import FastAPI
2
- from pydantic import BaseModel
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 
4
  import torch
 
 
 
 
 
 
 
 
 
5
 
6
- app = FastAPI()
7
 
8
  MODEL_NAME = "5CD-AI/Vintern-1B-v2"
 
 
9
 
10
- print("Loading tokenizer...")
11
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
 
13
- print("Loading model (INT4, CPU)...")
14
- model = AutoModelForCausalLM.from_pretrained(
 
 
 
 
 
15
  MODEL_NAME,
16
- load_in_4bit=True,
17
- device_map="cpu",
18
- torch_dtype=torch.float16
 
 
 
 
 
 
 
 
19
  )
20
 
21
- class InferRequest(BaseModel):
22
- text: str
23
 
24
- @app.post("/infer")
25
- def infer(req: InferRequest):
26
- inputs = tokenizer(
27
- req.text,
28
- return_tensors="pt",
29
- truncation=True,
30
- max_length=512
 
 
 
 
 
 
 
 
 
31
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  with torch.no_grad():
34
- output = model.generate(
35
- **inputs,
36
- max_new_tokens=256,
37
- do_sample=False
 
38
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- result = tokenizer.decode(output[0], skip_special_tokens=True)
41
- return {"result": result}
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import ast
4
+ import re
5
+ from io import BytesIO
6
+
7
  import torch
8
+ import torchvision.transforms as T
9
+ from PIL import Image
10
+ from torchvision.transforms.functional import InterpolationMode
11
+ from fastapi import FastAPI, HTTPException
12
+ from pydantic import BaseModel
13
+ from transformers import AutoModel, AutoTokenizer
14
+
15
+
16
+ app = FastAPI(title="CCCD OCR with Vintern-1B-v2")
17
 
 
18
 
19
  MODEL_NAME = "5CD-AI/Vintern-1B-v2"
20
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
21
+ DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
22
 
 
 
23
 
24
+ print(f"Loading model `{MODEL_NAME}` on {DEVICE} ...")
25
+ tokenizer = AutoTokenizer.from_pretrained(
26
+ MODEL_NAME,
27
+ trust_remote_code=True,
28
+ use_fast=False,
29
+ )
30
+ model = AutoModel.from_pretrained(
31
  MODEL_NAME,
32
+ torch_dtype=DTYPE,
33
+ low_cpu_mem_usage=True,
34
+ trust_remote_code=True,
35
+ )
36
+ model.eval().to(DEVICE)
37
+
38
+ generation_config = dict(
39
+ max_new_tokens=512,
40
+ do_sample=False,
41
+ num_beams=3,
42
+ repetition_penalty=3.5,
43
  )
44
 
 
 
45
 
46
+ # =========================
47
+ # Image preprocessing (from notebook)
48
+ # =========================
49
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
50
+ IMAGENET_STD = (0.229, 0.224, 0.225)
51
+
52
+
53
+ def build_transform(input_size: int):
54
+ mean, std = IMAGENET_MEAN, IMAGENET_STD
55
+ transform = T.Compose(
56
+ [
57
+ T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
58
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
59
+ T.ToTensor(),
60
+ T.Normalize(mean=mean, std=std),
61
+ ]
62
  )
63
+ return transform
64
+
65
+
66
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
67
+ best_ratio_diff = float("inf")
68
+ best_ratio = (1, 1)
69
+ area = width * height
70
+ for ratio in target_ratios:
71
+ target_aspect_ratio = ratio[0] / ratio[1]
72
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
73
+ if ratio_diff < best_ratio_diff:
74
+ best_ratio_diff = ratio_diff
75
+ best_ratio = ratio
76
+ elif ratio_diff == best_ratio_diff:
77
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
78
+ best_ratio = ratio
79
+ return best_ratio
80
+
81
+
82
+ def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
83
+ orig_width, orig_height = image.size
84
+ aspect_ratio = orig_width / orig_height
85
+
86
+ target_ratios = set(
87
+ (i, j)
88
+ for n in range(min_num, max_num + 1)
89
+ for i in range(1, n + 1)
90
+ for j in range(1, n + 1)
91
+ if i * j <= max_num and i * j >= min_num
92
+ )
93
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
94
+
95
+ target_aspect_ratio = find_closest_aspect_ratio(
96
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size
97
+ )
98
+
99
+ target_width = image_size * target_aspect_ratio[0]
100
+ target_height = image_size * target_aspect_ratio[1]
101
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
102
+
103
+ resized_img = image.resize((target_width, target_height))
104
+ processed_images = []
105
+ for i in range(blocks):
106
+ box = (
107
+ (i % (target_width // image_size)) * image_size,
108
+ (i // (target_width // image_size)) * image_size,
109
+ ((i % (target_width // image_size)) + 1) * image_size,
110
+ ((i // (target_width // image_size)) + 1) * image_size,
111
+ )
112
+ split_img = resized_img.crop(box)
113
+ processed_images.append(split_img)
114
+ assert len(processed_images) == blocks
115
+ if use_thumbnail and len(processed_images) != 1:
116
+ thumbnail_img = image.resize((image_size, image_size))
117
+ processed_images.append(thumbnail_img)
118
+ return processed_images
119
+
120
 
121
+ def load_image_from_base64(base64_string: str, input_size=448, max_num=12):
122
+ if base64_string.startswith("data:image"):
123
+ base64_string = base64_string.split(",", 1)[1]
124
+
125
+ image_data = base64.b64decode(base64_string)
126
+ image = Image.open(BytesIO(image_data)).convert("RGB")
127
+ transform = build_transform(input_size=input_size)
128
+ images = dynamic_preprocess(
129
+ image, image_size=input_size, use_thumbnail=True, max_num=max_num
130
+ )
131
+ pixel_values = [transform(img) for img in images]
132
+ pixel_values = torch.stack(pixel_values)
133
+ return pixel_values
134
+
135
+
136
+ # =========================
137
+ # Prompt & helpers
138
+ # =========================
139
+ PROMPT = """<image>
140
+ Bạn là hệ thống OCR + trích xuất dữ liệu từ ảnh Căn cước công dân (CCCD) Việt Nam.
141
+ Nhiệm vụ: đọc đúng chữ trên thẻ và trả về CHỈ 1 đối tượng JSON theo schema quy định.
142
+
143
+ QUY TẮC BẮT BUỘC:
144
+ 1) Chỉ trả về JSON thuần (không markdown, không giải thích, không thêm ký tự nào ngoài JSON).
145
+ 2) Chỉ được có đúng 5 khóa sau (đúng chính tả, đúng chữ thường, có dấu gạch dưới):
146
+ - "so_no"
147
+ - "ho_va_ten"
148
+ - "ngay_sinh"
149
+ - "que_quan"
150
+ - "noi_thuong_tru"
151
+ Không được thêm bất kỳ khóa nào khác.
152
+ 3) Mapping trường (lấy theo NHÃN in trên thẻ, không lấy từ QR):
153
+ - so_no: lấy giá trị ngay sau nhãn "Số / No." (hoặc "Số/No.").
154
+ - ho_va_ten: lấy giá trị ngay sau nhãn "Họ và tên / Full name".
155
+ - ngay_sinh: lấy giá trị ngay sau nhãn "Ngày sinh / Date of birth"; nếu có định dạng dd/mm/yyyy thì giữ đúng dd/mm/yyyy.
156
+ - que_quan: lấy giá trị ngay sau nhãn "Quê quán / Place of origin".
157
+ - noi_thuong_tru: lấy giá trị ngay sau nhãn "Nơi thường trú / Place of residence".
158
+ 4) Nếu trường nào không đọc được rõ/chắc chắn: đặt null. Không được suy đoán.
159
+ 5) Chuẩn hoá: trim khoảng trắng đầu/cuối; giữ nguyên dấu tiếng Việt và chữ hoa/thường như trong ảnh.
160
+
161
+ CHỈ TRẢ VỀ THEO MẪU JSON NÀY:
162
+ {
163
+ "so_no": "... hoặc null",
164
+ "ho_va_ten": "... hoặc null",
165
+ "ngay_sinh": "... hoặc null",
166
+ "que_quan": "... hoặc null",
167
+ "noi_thuong_tru": "... hoặc null"
168
+ }
169
+ """
170
+
171
+
172
+ def parse_response_to_json(response_text: str):
173
+ if not response_text:
174
+ return None
175
+
176
+ s = response_text.strip()
177
+
178
+ if s.startswith('"') and s.endswith('"'):
179
+ s = s[1:-1].replace('\\"', '"')
180
+
181
+ try:
182
+ obj = json.loads(s)
183
+ if isinstance(obj, dict):
184
+ return obj
185
+ except json.JSONDecodeError:
186
+ pass
187
+
188
+ try:
189
+ obj = ast.literal_eval(s)
190
+ if isinstance(obj, dict):
191
+ return obj
192
+ except (ValueError, SyntaxError):
193
+ pass
194
+
195
+ json_pattern = r"\{[\s\S]*\}"
196
+ m = re.search(json_pattern, s)
197
+ if m:
198
+ chunk = m.group(0).strip()
199
+ try:
200
+ obj = ast.literal_eval(chunk)
201
+ if isinstance(obj, dict):
202
+ return obj
203
+ except Exception:
204
+ pass
205
+ try:
206
+ chunk2 = chunk.replace("'", '"')
207
+ obj = json.loads(chunk2)
208
+ if isinstance(obj, dict):
209
+ return obj
210
+ except Exception:
211
+ pass
212
+
213
+ return {"text": response_text}
214
+
215
+
216
+ def normalize_base64(image_base64: str) -> str:
217
+ if not image_base64:
218
+ return image_base64
219
+ image_base64 = image_base64.strip()
220
+ if image_base64.startswith("data:"):
221
+ parts = image_base64.split(",", 1)
222
+ if len(parts) == 2:
223
+ return parts[1]
224
+ return image_base64
225
+
226
+
227
+ def ocr_by_llm(image_base64: str, prompt: str) -> str:
228
+ pixel_values = load_image_from_base64(image_base64, max_num=6)
229
+ if DEVICE == "cuda":
230
+ pixel_values = pixel_values.to(dtype=torch.bfloat16, device=DEVICE)
231
+ else:
232
+ pixel_values = pixel_values.to(dtype=torch.float32, device=DEVICE)
233
  with torch.no_grad():
234
+ response_message = model.chat(
235
+ tokenizer,
236
+ pixel_values,
237
+ prompt,
238
+ generation_config,
239
  )
240
+ del pixel_values
241
+ return response_message
242
+
243
+
244
+ class OCRRequest(BaseModel):
245
+ image_base64: str
246
+
247
+
248
+ @app.post("/ocr")
249
+ def ocr_endpoint(req: OCRRequest):
250
+ image_base64 = normalize_base64(req.image_base64)
251
+ if not image_base64:
252
+ raise HTTPException(status_code=400, detail="image_base64 is required")
253
 
254
+ try:
255
+ response_message = ocr_by_llm(image_base64, PROMPT)
256
+ parsed = parse_response_to_json(response_message)
257
+ return {"response_message": parsed}
258
+ except Exception as e:
259
+ raise HTTPException(status_code=500, detail=str(e))