KarthiEz committed
Commit 542e67d · verified · Parent: 4dfba61

Update app.py

Files changed (1):
  1. app.py +77 -147

app.py CHANGED
@@ -83,17 +83,33 @@ def _gc():
     torch.cuda.empty_cache()
 
 
-def build_hunyuan_messages(history_messages, latest_user_text, image_path):
+def build_hunyuan_messages_from_history(history, image_path, latest_user_text):
     """
-    history_messages: list of {'role', 'content'} for past turns
-    latest_user_text: str (current user message)
-    image_path: filepath of last uploaded image (or None)
-
-    Returns: new list of messages including current user turn
+    history: list of [user_text, assistant_text] pairs from ChatInterface
+    image_path: current uploaded image file path (or None)
+    latest_user_text: current user message (str)
+    Returns: list[{"role": ..., "content": [...]}] for HunYuan
     """
-    messages = copy.deepcopy(history_messages)
+    messages = []
 
-    # Build content for the current user turn
+    # 1) Past turns (only text – image reused only for current turn)
+    for user, assistant in history:
+        # user
+        messages.append(
+            {
+                "role": "user",
+                "content": [{"type": "text", "text": user}],
+            }
+        )
+        # assistant
+        messages.append(
+            {
+                "role": "assistant",
+                "content": [{"type": "text", "text": assistant}],
+            }
+        )
+
+    # 2) Current user turn (image + text)
     content = []
     if image_path:
         content.append(
@@ -105,71 +121,26 @@ def build_hunyuan_messages(history_messages, latest_user_text, image_path):
     if latest_user_text:
         content.append({"type": "text", "text": latest_user_text})
 
-    if not content:
-        # No text, no image → don't add a turn
-        return messages
+    if content:
+        messages.append({"role": "user", "content": content})
 
-    messages.append({"role": "user", "content": content})
     return messages
 
 
-def rebuild_chat_display(history_messages):
-    """
-    Convert internal Hunyuan-like messages into classic
-    Chatbot display: list of (user_str, assistant_str) tuples.
-    """
-    chat = []
-    last_user = None
-
-    for msg in history_messages:
-        role = msg.get("role")
-        content = msg.get("content", [])
-
-        if role == "user":
-            # Collect only text pieces for display
-            if isinstance(content, list):
-                text_parts = [
-                    c.get("text", "")
-                    for c in content
-                    if isinstance(c, dict) and c.get("type") == "text"
-                ]
-                user_text = " ".join(tp for tp in text_parts if tp.strip())
-                if not user_text:
-                    user_text = "[image]"
-            else:
-                user_text = str(content)
-            last_user = user_text
-
-        elif role == "assistant":
-            if isinstance(content, list):
-                text_parts = [
-                    c.get("text", "")
-                    for c in content
-                    if isinstance(c, dict) and c.get("type") == "text"
-                ]
-                bot_text = " ".join(tp for tp in text_parts if tp.strip())
-            else:
-                bot_text = str(content)
-
-            if last_user is None:
-                last_user = ""
-            chat.append((last_user, bot_text))
-            last_user = None
-
-    return chat
-
-
 def main():
     args = _get_args()
     model, processor = _load_model_processor(args)
 
+    # -------------------------
+    # Core model call
+    # -------------------------
     @spaces.GPU(duration=120)
     def call_local_model(hy_messages):
         import time
 
         start_time = time.time()
 
-        # Hunyuan expects list[list[message]]
+        # HunYuan expects list[list[message]]
         convs = [hy_messages]
 
         texts = [
@@ -192,7 +163,7 @@ def main():
         device = "cuda" if torch.cuda.is_available() else "cpu"
         inputs = inputs.to(device)
 
-        max_new_tokens = 2048
+        max_new_tokens = 512  # keep this smaller on CPU
        with torch.no_grad():
             if device == "cuda":
                 with torch.cuda.amp.autocast(dtype=torch.bfloat16):
@@ -221,103 +192,62 @@ def main():
         return text
 
     # -------------------------
-    # Gradio UI (Blocks)
+    # Chat handler for ChatInterface
+    # -------------------------
+    def ocr_chat(message, history, image_path):
+        """
+        message: current user text (str)
+        history: list[[user, assistant], ...]
+        image_path: filepath from Image component
+        """
+        message = (message or "").strip()
+
+        if not message and not image_path:
+            return "Please upload an image and/or type a question."
+
+        hy_messages = build_hunyuan_messages_from_history(
+            history or [], image_path, message
+        )
+        answer = call_local_model(hy_messages)
+        return answer
+
+    # -------------------------
+    # UI: ChatInterface + image
     # -------------------------
     with gr.Blocks() as demo:
         gr.Markdown(
             "# HunyuanOCR\n"
-            "*Upload an image (invoice, document, receipt, notice, etc.) and ask OCR questions.*"
+            "Upload an image (invoice, document, receipt, notice, etc.) and ask OCR questions."
         )
 
         with gr.Row():
             with gr.Column(scale=2):
-                chatbot = gr.Chatbot(
-                    label="HunyuanOCR Chat",
-                    height=600,
-                    type="messages",  # ✅ explicitly say we are using messages format
-                )
-                user_input = gr.Textbox(
-                    label="Your question",
-                    placeholder="Example: Detect and recognize all text in this image.",
-                    lines=2,
-                )
-                with gr.Row():
-                    send_btn = gr.Button("Send", variant="primary")
-                    clear_btn = gr.Button("Clear Chat")
-
-            with gr.Column(scale=1):
-                image_input = gr.Image(
-                    label="Upload image",
-                    type="filepath",
-                )
-                gr.Markdown(
-                    "Tips:\n"
-                    "- Use clear, high-resolution scans.\n"
-                    "- Supported: JPG, PNG.\n"
-                    "- You can reuse the same image for multiple questions."
+                chat = gr.ChatInterface(
+                    fn=ocr_chat,
+                    chatbot=gr.Chatbot(
+                        label="HunyuanOCR Chat",
+                        height=600,
+                    ),
+                    textbox=gr.Textbox(
+                        label="Your question",
+                        placeholder="Example: Detect and recognize all text in this image.",
+                        lines=2,
+                    ),
+                    additional_inputs=[
+                        gr.Image(
+                            label="Upload image",
+                            type="filepath",
+                        )
+                    ],
+                    title=None,
+                    description=None,
                 )
 
-        # Internal states:
-        # - history_messages: list of Hunyuan-style messages
-        # - image_state: latest uploaded image path
-        history_messages = gr.State([])
-        image_state = gr.State(None)
-
-        # Handler: on image upload → just store path in state
-        def on_image_upload(img_path):
-            # img_path is already a filepath (type='filepath')
-            return img_path
-
-        image_input.upload(
-            on_image_upload,
-            inputs=image_input,
-            outputs=image_state,
-        )
-
-        # Handler: main send logic
-        def on_send(text, chat_value, history_msgs, img_path):
-            # If nothing to do, return unchanged
-            if (not text or not text.strip()) and not img_path:
-                return chat_value, history_msgs, ""
-
-            # 1) Build messages with new user turn
-            messages = build_hunyuan_messages(history_msgs, text.strip(), img_path)
-
-            # 2) Call model
-            answer = call_local_model(messages)
-
-            # 3) Append assistant turn to history
-            messages.append(
-                {
-                    "role": "assistant",
-                    "content": [{"type": "text", "text": answer}],
-                }
-            )
-
-            # 4) Return updated messages both to Chatbot and history_state
-            # Chatbot (type="messages") expects this format directly
-            return messages, messages, ""
-
-        send_btn.click(
-            on_send,
-            inputs=[user_input, chatbot, history_messages, image_state],
-            outputs=[chatbot, history_messages, user_input],
-        )
-        user_input.submit(
-            on_send,
-            inputs=[user_input, chatbot, history_messages, image_state],
-            outputs=[chatbot, history_messages, user_input],
-        )
-
-        # Clear everything
-        def on_clear():
-            _gc()
-            return [], [], None, None
-
-        clear_btn.click(
-            on_clear,
-            inputs=[],
-            outputs=[chatbot, history_messages, image_input, image_state],
+                gr.Markdown(
+                    "Tips:\n"
+                    "- Use clear, high-resolution scans.\n"
+                    "- Supported: JPG, PNG.\n"
+                    "- You can reuse the same image for multiple questions."
                 )
 
     demo.queue().launch(
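For reference, a minimal sketch of the message list the new build_hunyuan_messages_from_history helper returns for a one-turn history plus a fresh image question. The exact shape of the image entry is truncated in the first hunk (the content.append( call is cut off), so the {"type": "image", "image": path} form below is an assumption, not taken from this commit:

    # Hypothetical example; the image-entry dict shape is assumed.
    history = [["What kind of document is this?", "An invoice."]]
    msgs = build_hunyuan_messages_from_history(
        history, "/tmp/invoice.png", "Read the total amount."
    )
    # msgs ==
    # [
    #     {"role": "user", "content": [{"type": "text", "text": "What kind of document is this?"}]},
    #     {"role": "assistant", "content": [{"type": "text", "text": "An invoice."}]},
    #     {"role": "user", "content": [{"type": "image", "image": "/tmp/invoice.png"},
    #                                  {"type": "text", "text": "Read the total amount."}]},
    # ]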
 
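The rewrite drops the manual gr.State / send_btn.click plumbing in favor of gr.ChatInterface, which passes each additional_inputs component's value to fn after (message, history). A stripped-down sketch of just that wiring (the component labels are reused from the diff; the echo handler is hypothetical):

    import gradio as gr

    def echo_ocr(message, history, image_path):
        # additional_inputs values arrive after (message, history), in declaration order
        return f"text={message!r}, image={image_path!r}"

    demo = gr.ChatInterface(
        fn=echo_ocr,
        additional_inputs=[gr.Image(label="Upload image", type="filepath")],
    )
    demo.launch()

This is also why ocr_chat can simply return a string: ChatInterface appends the return value as the assistant turn and manages the history itself.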
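One note on the generation hunk: max_new_tokens drops from 2048 to 512 (the inline comment says to keep it smaller on CPU), while the torch.cuda.amp.autocast context is carried over unchanged. Recent PyTorch releases flag that spelling as deprecated in favor of torch.autocast("cuda", ...); a sketch of the same guarded generation step using the newer form (the generate kwargs follow common transformers usage and are assumptions here):

    import torch

    def generate_ids(model, inputs, max_new_tokens=512):
        # bf16 autocast on GPU, plain fp32 path on CPU, mirroring the diff's structure
        device = "cuda" if torch.cuda.is_available() else "cpu"
        with torch.no_grad():
            if device == "cuda":
                with torch.autocast("cuda", dtype=torch.bfloat16):
                    return model.generate(**inputs, max_new_tokens=max_new_tokens)
            return model.generate(**inputs, max_new_tokens=max_new_tokens)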