HalogenFlo commited on
Commit
5036d72
·
verified ·
1 Parent(s): 1cd1273

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -13
app.py CHANGED
@@ -82,6 +82,40 @@ def predict_character(image):
82
  print(f"Error predicting character: {e}")
83
  return {}
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
 
87
  print("Loading llm processor and model...")
@@ -165,20 +199,41 @@ with gr.Blocks(css=custom_css, title="TIC AI Hub") as demo:
165
  # Tab 2: Handwriting Recognition
166
  with gr.TabItem("Handwriting Recognition"):
167
  gr.Markdown("### Recognize handwritten characters and digits using ViT")
168
- with gr.Row():
169
- with gr.Column():
170
- img_input = gr.Sketchpad(
171
- label="Draw a character on the sketchpad below",
172
- type="pil"
173
- )
174
  with gr.Row():
175
- clear_btn_h = gr.Button("Clear", elem_classes="secondary-btn")
176
- submit_btn_h = gr.Button("Predict", elem_classes="primary-btn")
177
- with gr.Column():
178
- lbl_handwrite = gr.Label(label="Top 5 Predicted Characters", num_top_classes=5)
179
-
180
- submit_btn_h.click(fn=predict_character, inputs=img_input, outputs=lbl_handwrite)
181
- clear_btn_h.click(fn=lambda: (None, None), outputs=[img_input, lbl_handwrite])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
  # Tab 3: Chatbot
184
  with gr.TabItem("AI Chatbot"):
 
82
  print(f"Error predicting character: {e}")
83
  return {}
84
 
85
+ def predict_character_upload(image):
86
+ """Dự đoán ký tự từ ảnh upload (PIL Image trực tiếp, không qua Sketchpad dict)."""
87
+ if image is None:
88
+ return {}
89
+
90
+ try:
91
+ # Ảnh upload là PIL Image trực tiếp
92
+ pil_image = image if isinstance(image, Image.Image) else Image.open(image)
93
+
94
+ # Chuyển sang grayscale
95
+ gray_image = pil_image.convert("L")
96
+
97
+ # EMNIST: nền đen, nét trắng → invert nếu nền sáng
98
+ avg_color = np.mean(np.array(gray_image))
99
+ if avg_color > 127:
100
+ gray_image = ImageOps.invert(gray_image)
101
+
102
+ # Chuyển RGB và resize cho ViT
103
+ rgb_image = gray_image.convert("RGB").resize((224, 224))
104
+
105
+ inputs = process(images=rgb_image, return_tensors="pt").to(device)
106
+ with torch.no_grad():
107
+ outputs = emnist_model(**inputs)
108
+
109
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
110
+ topk_probs, topk_idx = torch.topk(probs, 5)
111
+
112
+ return {
113
+ emnist_labels[int(idx.item())]: float(val.item())
114
+ for val, idx in zip(topk_probs, topk_idx)
115
+ }
116
+ except Exception as e:
117
+ print(f"Error predicting from uploaded image: {e}")
118
+ return {}
119
 
120
 
121
  print("Loading llm processor and model...")
 
199
  # Tab 2: Handwriting Recognition
200
  with gr.TabItem("Handwriting Recognition"):
201
  gr.Markdown("### Recognize handwritten characters and digits using ViT")
202
+ with gr.Tabs():
203
+ # Sub-tab: Vẽ tay
204
+ with gr.TabItem("✏️ Draw"):
 
 
 
205
  with gr.Row():
206
+ with gr.Column():
207
+ img_input = gr.Sketchpad(
208
+ label="Draw a character on the sketchpad below",
209
+ type="pil"
210
+ )
211
+ with gr.Row():
212
+ clear_btn_h = gr.Button("Clear", elem_classes="secondary-btn")
213
+ submit_btn_h = gr.Button("Predict", elem_classes="primary-btn")
214
+ with gr.Column():
215
+ lbl_handwrite = gr.Label(label="Top 5 Predicted Characters", num_top_classes=5)
216
+
217
+ submit_btn_h.click(fn=predict_character, inputs=img_input, outputs=lbl_handwrite)
218
+ clear_btn_h.click(fn=lambda: (None, None), outputs=[img_input, lbl_handwrite])
219
+
220
+ # Sub-tab: Upload ảnh
221
+ with gr.TabItem("📷 Upload Image"):
222
+ with gr.Row():
223
+ with gr.Column():
224
+ img_upload = gr.Image(
225
+ label="Upload an image of a handwritten character",
226
+ type="pil",
227
+ sources=["upload", "clipboard"]
228
+ )
229
+ with gr.Row():
230
+ clear_btn_u = gr.Button("Clear", elem_classes="secondary-btn")
231
+ submit_btn_u = gr.Button("Predict", elem_classes="primary-btn")
232
+ with gr.Column():
233
+ lbl_upload = gr.Label(label="Top 5 Predicted Characters", num_top_classes=5)
234
+
235
+ submit_btn_u.click(fn=predict_character_upload, inputs=img_upload, outputs=lbl_upload)
236
+ clear_btn_u.click(fn=lambda: (None, None), outputs=[img_upload, lbl_upload])
237
 
238
  # Tab 3: Chatbot
239
  with gr.TabItem("AI Chatbot"):