Sangjun2 commited on
Commit
fb56a77
ยท
verified ยท
1 Parent(s): ca427b1

new_new_new_vaiv_app.py

Browse files
Files changed (1) hide show
  1. app.py +182 -983
app.py CHANGED
@@ -20,6 +20,9 @@ import time
20
  import logging
21
  import subprocess
22
  import spaces
 
 
 
23
 
24
  # Git LFS pull ๋ช…๋ น์–ด ์‹คํ–‰
25
  result = subprocess.run(['git', 'lfs', 'pull'], capture_output=True, text=True)
@@ -36,55 +39,26 @@ logger = logging.getLogger()
36
  warnings.filterwarnings('ignore')
37
  MAX_PATCHES = 512
38
  # Load the models and processor
39
- #device = torch.device("cpu")
40
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
 
42
  # Paths to the models
43
- ko_deplot_model_path = './deplot_model_ver_kor_24.7.25_refinetuning_epoch3.bin'
44
- aihub_deplot_model_path='./deplot_k.pt'
45
- t5_model_path = './ke_t5.pt'
46
 
47
  # Load first model ko-deplot
48
-
49
  def load_model1():
50
  processor1 = Pix2StructProcessor.from_pretrained('nuua/ko-deplot')
51
  model1 = Pix2StructForConditionalGeneration.from_pretrained('nuua/ko-deplot')
52
  model1.load_state_dict(torch.load(ko_deplot_model_path, map_location="cpu"))
53
  model1.to(torch.device("cuda"))
54
- return processor1,model1
55
-
56
- processor1,model1=load_model1()
57
-
58
- # Load second model aihub-deplot
59
-
60
- def load_model2():
61
- processor2 = AutoProcessor.from_pretrained("ybelkada/pix2struct-base")
62
- model2 = Pix2StructForConditionalGeneration.from_pretrained("ybelkada/pix2struct-base")
63
- model2.load_state_dict(torch.load(aihub_deplot_model_path, map_location="cpu"))
64
- model2.to(torch.device("cuda"))
65
- return processor2,model2
66
-
67
- processor2,model2=load_model2()
68
 
 
69
 
70
- #Load third model unichart
71
-
72
- def load_model3():
73
- unichart_model_path = "./unichart4/chartqa-checkpoint-epoch=2-161952"
74
- model3 = VisionEncoderDecoderModel.from_pretrained(unichart_model_path)
75
- processor3 = DonutProcessor.from_pretrained(unichart_model_path)
76
- model3.to(torch.device("cuda"))
77
- return processor3,model3
78
-
79
- processor3,model3=load_model3()
80
-
81
- #ko-deplot ์ถ”๋ก ํ•จ์ˆ˜
82
  # Function to format output
83
  def format_output(prediction):
84
  return prediction.replace('<0x0A>', '\n')
85
 
86
- # First model prediction ko-deplot
87
- @spaces.GPU(enable_queue=True,duration=100)
88
  def predict_model1(image):
89
  images = [image]
90
  inputs = processor1(images=images, text="What is the title of the chart", return_tensors="pt", padding=True)
@@ -98,1003 +72,228 @@ def predict_model1(image):
98
  formatted_output = format_output(outputs[0])
99
  return formatted_output
100
 
101
-
102
- def replace_unk(text):
103
- # 1. '์ œ๋ชฉ:', '์œ ํ˜•:' ๊ธ€์ž ์•ž์— ์žˆ๋Š” <unk>๋Š” \n๋กœ ๋ฐ”๊ฟˆ
104
- text = re.sub(r'<unk>(?=์ œ๋ชฉ:|์œ ํ˜•:)', '\n', text)
105
- # 2. '์„ธ๋กœ ' ๋˜๋Š” '๊ฐ€๋กœ '์™€ '๋Œ€ํ˜•' ์‚ฌ์ด์— ์žˆ๋Š” <unk>๋ฅผ ""๋กœ ๋ฐ”๊ฟˆ
106
- text = re.sub(r'(?<=์„ธ๋กœ |๊ฐ€๋กœ )<unk>(?=๋Œ€ํ˜•)', '', text)
107
- # 3. ์ˆซ์ž์™€ ํ…์ŠคํŠธ ์‚ฌ์ด์— ์žˆ๋Š” <unk>๋ฅผ \n๋กœ ๋ฐ”๊ฟˆ
108
- text = re.sub(r'(\d)<unk>([^\d])', r'\1\n\2', text)
109
- # 4. %, ์›, ๊ฑด, ๋ช… ๋’ค์— ๋‚˜์˜ค๋Š” <unk>๋ฅผ \n๋กœ ๋ฐ”๊ฟˆ
110
- text = re.sub(r'(?<=[%์›๊ฑด๋ช…\)])<unk>', '\n', text)
111
- # 5. ์ˆซ์ž์™€ ์ˆซ์ž ์‚ฌ์ด์— ์žˆ๋Š” <unk>๋ฅผ \n๋กœ ๋ฐ”๊ฟˆ
112
- text = re.sub(r'(\d)<unk>(\d)', r'\1\n\2', text)
113
- # 6. 'ํ˜•'์ด๋ผ๋Š” ๊ธ€์ž์™€ ' |' ์‚ฌ์ด์— ์žˆ๋Š” <unk>๋ฅผ \n๋กœ ๋ฐ”๊ฟˆ
114
- text = re.sub(r'ํ˜•<unk>(?= \|)', 'ํ˜•\n', text)
115
- # 7. ๋‚˜๋จธ์ง€ <unk>๋ฅผ ๋ชจ๋‘ ""๋กœ ๋ฐ”๊ฟˆ
116
- text = text.replace('<unk>', '')
117
- return text
118
-
119
-
120
- @spaces.GPU(enable_queue=True,duration=100)
121
- def predict_model3(image):
122
- image=image.convert("RGB")
123
- input_prompt = "<extract_data_table> <s_answer>"
124
- decoder_input_ids = processor3.tokenizer(input_prompt, add_special_tokens=False, return_tensors="pt").input_ids
125
- pixel_values = processor3(image, return_tensors="pt").pixel_values
126
- outputs = model3.generate(
127
- pixel_values.to(device),
128
- decoder_input_ids=decoder_input_ids.to(device),
129
- max_length=model3.decoder.config.max_position_embeddings,
130
- early_stopping=True,
131
- pad_token_id=processor3.tokenizer.pad_token_id,
132
- eos_token_id=processor3.tokenizer.eos_token_id,
133
- use_cache=True,
134
- num_beams=4,
135
- bad_words_ids=[[processor3.tokenizer.unk_token_id]],
136
- return_dict_in_generate=True,
137
- )
138
- sequence = processor3.batch_decode(outputs.sequences)[0]
139
- sequence = sequence.replace(processor3.tokenizer.eos_token, "").replace(processor3.tokenizer.pad_token, "")
140
- sequence = sequence.split("<s_answer>")[-1].strip()
141
-
142
- return sequence
143
- #function for converting aihub dataset labeling json file to ko-deplot data table
144
- def process_json_file(input_file):
145
- with open(input_file, 'r', encoding='utf-8') as file:
146
- data = json.load(file)
147
-
148
- # ํ•„์š”ํ•œ ๋ฐ์ดํ„ฐ ์ถ”์ถœ
149
- chart_type = data['metadata']['chart_sub']
150
- title = data['annotations'][0]['title']
151
- x_axis = data['annotations'][0]['axis_label']['x_axis']
152
- y_axis = data['annotations'][0]['axis_label']['y_axis']
153
- legend = data['annotations'][0]['legend']
154
- data_labels = data['annotations'][0]['data_label']
155
- is_legend = data['annotations'][0]['is_legend']
156
-
157
- # ์›ํ•˜๋Š” ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜
158
- formatted_string = f"TITLE | {title} <0x0A> "
159
- if '๊ฐ€๋กœ' in chart_type:
160
- if is_legend:
161
- # ๊ฐ€๋กœ ์ฐจํŠธ ์ฒ˜๋ฆฌ
162
- formatted_string += " | ".join(legend) + " <0x0A> "
163
- for i in range(len(y_axis)):
164
- row = [y_axis[i]]
165
- for j in range(len(legend)):
166
- if i < len(data_labels[j]):
167
- row.append(str(data_labels[j][i])) # ๋ฐ์ดํ„ฐ ๊ฐ’์„ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜
168
- else:
169
- row.append("") # ๋ฐ์ดํ„ฐ๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ๋นˆ ๋ฌธ์ž์—ด ์ถ”๊ฐ€
170
- formatted_string += " | ".join(row) + " <0x0A> "
171
- else:
172
- # is_legend๊ฐ€ False์ธ ๊ฒฝ์šฐ
173
- for i in range(len(y_axis)):
174
- row = [y_axis[i], str(data_labels[0][i])]
175
- formatted_string += " | ".join(row) + " <0x0A> "
176
- elif chart_type == "์›ํ˜•":
177
- # ์›ํ˜• ์ฐจํŠธ ์ฒ˜๋ฆฌ
178
- if legend:
179
- used_labels = legend
180
- else:
181
- used_labels = x_axis
182
-
183
- formatted_string += " | ".join(used_labels) + " <0x0A> "
184
- row = [data_labels[0][i] for i in range(len(used_labels))]
185
- formatted_string += " | ".join(row) + " <0x0A> "
186
- elif chart_type == "ํ˜ผํ•ฉํ˜•":
187
- # ํ˜ผํ•ฉํ˜• ์ฐจํŠธ ์ฒ˜๋ฆฌ
188
- all_legends = [ann['legend'][0] for ann in data['annotations']]
189
- formatted_string += " | ".join(all_legends) + " <0x0A> "
190
-
191
- combined_data = []
192
- for i in range(len(x_axis)):
193
- row = [x_axis[i]]
194
- for ann in data['annotations']:
195
- if i < len(ann['data_label'][0]):
196
- row.append(str(ann['data_label'][0][i])) # ๋ฐ์ดํ„ฐ ๊ฐ’์„ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜
197
- else:
198
- row.append("") # ๋ฐ์ดํ„ฐ๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ๋นˆ ๋ฌธ์ž์—ด ์ถ”๊ฐ€
199
- combined_data.append(" | ".join(row))
200
-
201
- formatted_string += " <0x0A> ".join(combined_data) + " <0x0A> "
202
- else:
203
- # ๊ธฐํƒ€ ์ฐจํŠธ ์ฒ˜๋ฆฌ
204
- if is_legend:
205
- formatted_string += " | ".join(legend) + " <0x0A> "
206
- for i in range(len(x_axis)):
207
- row = [x_axis[i]]
208
- for j in range(len(legend)):
209
- if i < len(data_labels[j]):
210
- row.append(str(data_labels[j][i])) # ๋ฐ์ดํ„ฐ ๊ฐ’์„ ๋ฌธ์ž์—ด๋กœ ๋ณ€ํ™˜
211
- else:
212
- row.append("") # ๋ฐ์ดํ„ฐ๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ๋นˆ ๋ฌธ์ž์—ด ์ถ”๊ฐ€
213
- formatted_string += " | ".join(row) + " <0x0A> "
214
- else:
215
- for i in range(len(x_axis)):
216
- if i < len(data_labels[0]):
217
- formatted_string += f"{x_axis[i]} | {str(data_labels[0][i])} <0x0A> "
218
- else:
219
- formatted_string += f"{x_axis[i]} | <0x0A> " # ๋ฐ์ดํ„ฐ๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ ๋นˆ ๋ฌธ์ž์—ด ์ถ”๊ฐ€
220
-
221
- # ๋งˆ์ง€๋ง‰ "<0x0A> " ์ œ๊ฑฐ
222
- formatted_string = formatted_string[:-8]
223
- return format_output(formatted_string)
224
-
225
- def chart_data(data):
226
- datatable = []
227
- num = len(data)
228
- for n in range(num):
229
- title = data[n]['title'] if data[n]['is_title'] else ''
230
- legend = data[n]['legend'] if data[n]['is_legend'] else ''
231
- datalabel = data[n]['data_label'] if data[n]['is_datalabel'] else [0]
232
- unit = data[n]['unit'] if data[n]['is_unit'] else ''
233
- base = data[n]['base'] if data[n]['is_base'] else ''
234
- x_axis_title = data[n]['axis_title']['x_axis']
235
- y_axis_title = data[n]['axis_title']['y_axis']
236
- x_axis = data[n]['axis_label']['x_axis'] if data[n]['is_axis_label_x_axis'] else [0]
237
- y_axis = data[n]['axis_label']['y_axis'] if data[n]['is_axis_label_y_axis'] else [0]
238
-
239
- if len(legend) > 1:
240
- datalabel = np.array(datalabel).transpose().tolist()
241
-
242
- datatable.append([title, legend, datalabel, unit, base, x_axis_title, y_axis_title, x_axis, y_axis])
243
-
244
- return datatable
245
-
246
- def datatable(data, chart_type):
247
- data_table = ''
248
- num = len(data)
249
-
250
- if len(data) == 2:
251
- temp = []
252
- temp.append(f"๋Œ€์ƒ: {data[0][4]}")
253
- temp.append(f"์ œ๋ชฉ: {data[0][0]}")
254
- temp.append(f"์œ ํ˜•: {' '.join(chart_type[0:2])}")
255
- temp.append(f"{data[0][5]} | {data[0][1][0]}({data[0][3]}) | {data[1][1][0]}({data[1][3]})")
256
-
257
- x_axis = data[0][7]
258
- for idx, x in enumerate(x_axis):
259
- temp.append(f"{x} | {data[0][2][0][idx]} | {data[1][2][0][idx]}")
260
-
261
- data_table = '\n'.join(temp)
262
- else:
263
- for n in range(num):
264
- temp = []
265
-
266
- title, legend, datalabel, unit, base, x_axis_title, y_axis_title, x_axis, y_axis = data[n]
267
- legend = [element + f"({unit})" for element in legend]
268
-
269
- if len(legend) > 1:
270
- temp.append(f"๋Œ€์ƒ: {base}")
271
- temp.append(f"์ œ๋ชฉ: {title}")
272
- temp.append(f"์œ ํ˜•: {' '.join(chart_type[0:2])}")
273
- temp.append(f"{x_axis_title} | {' | '.join(legend)}")
274
-
275
- if chart_type[2] == "์›ํ˜•":
276
- datalabel = sum(datalabel, [])
277
- temp.append(f"{' | '.join([str(d) for d in datalabel])}")
278
- data_table = '\n'.join(temp)
279
- else:
280
- axis = y_axis if chart_type[2] == "๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•" else x_axis
281
- for idx, (x, d) in enumerate(zip(axis, datalabel)):
282
- temp_d = [str(e) for e in d]
283
- temp_d = " | ".join(temp_d)
284
- row = f"{x} | {temp_d}"
285
- temp.append(row)
286
- data_table = '\n'.join(temp)
287
- else:
288
- temp.append(f"๋Œ€์ƒ: {base}")
289
- temp.append(f"์ œ๋ชฉ: {title}")
290
- temp.append(f"์œ ํ˜•: {' '.join(chart_type[0:2])}")
291
- temp.append(f"{x_axis_title} | {unit}")
292
- axis = y_axis if chart_type[2] == "๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•" else x_axis
293
- datalabel = datalabel[0]
294
-
295
- for idx, x in enumerate(axis):
296
- row = f"{x} | {str(datalabel[idx])}"
297
- temp.append(row)
298
- data_table = '\n'.join(temp)
299
-
300
- return data_table
301
-
302
- #function for converting aihub dataset labeling json file to aihub-deplot data table
303
- def process_json_file2(input_file):
304
- with open(input_file, 'r', encoding='utf-8') as file:
305
- data = json.load(file)
306
- # ํ•„์š”ํ•œ ๋ฐ์ดํ„ฐ ์ถ”์ถœ
307
- chart_multi = data['metadata']['chart_multi']
308
- chart_main = data['metadata']['chart_main']
309
- chart_sub = data['metadata']['chart_sub']
310
- chart_type = [chart_multi, chart_sub, chart_main]
311
- chart_annotations = data['annotations']
312
-
313
- charData = chart_data(chart_annotations)
314
- dataTable = datatable(charData, chart_type)
315
- return dataTable
316
-
317
- # RMS
318
- def _to_float(text): # ๋‹จ์œ„ ๋–ผ๊ณ  ์ˆซ์ž๋งŒ..?
319
- try:
320
- if text.endswith("%"):
321
- # Convert percentages to floats.
322
- return float(text.rstrip("%")) / 100.0
323
- else:
324
- return float(text)
325
- except ValueError:
326
- return None
327
-
328
-
329
- def _get_relative_distance(
330
- target, prediction, theta = 1.0
331
- ):
332
- """Returns min(1, |target-prediction|/|target|)."""
333
- if not target:
334
- return int(not prediction)
335
- distance = min(abs((target - prediction) / target), 1)
336
- return distance if distance < theta else 1
337
-
338
- def anls_metric(target: str, prediction: str, theta: float = 0.5):
339
- edit_distance = editdistance.eval(target, prediction)
340
- normalize_ld = edit_distance / max(len(target), len(prediction))
341
- return 1 - normalize_ld if normalize_ld < theta else 0
342
-
343
- def _permute(values, indexes):
344
- return tuple(values[i] if i < len(values) else "" for i in indexes)
345
-
346
-
347
- @dataclasses.dataclass(frozen=True)
348
- class Table:
349
- """Helper class for the content of a markdown table."""
350
-
351
- base: Optional[str] = None
352
- title: Optional[str] = None
353
- chartType: Optional[str] = None
354
- headers: tuple[str, Ellipsis] = dataclasses.field(default_factory=tuple)
355
- rows: tuple[tuple[str, Ellipsis], Ellipsis] = dataclasses.field(default_factory=tuple)
356
-
357
- def permuted(self, indexes):
358
- """Builds a version of the table changing the column order."""
359
- return Table(
360
- base=self.base,
361
- title=self.title,
362
- chartType=self.chartType,
363
- headers=_permute(self.headers, indexes),
364
- rows=tuple(_permute(row, indexes) for row in self.rows),
365
- )
366
-
367
- def aligned(
368
- self, headers, text_theta = 0.5
369
- ):
370
- """Builds a column permutation with headers in the most correct order."""
371
- if len(headers) != len(self.headers):
372
- raise ValueError(f"Header length {headers} must match {self.headers}.")
373
- distance = []
374
- for h2 in self.headers:
375
- distance.append(
376
- [
377
- 1 - anls_metric(h1, h2, text_theta)
378
- for h1 in headers
379
- ]
380
- )
381
- cost_matrix = np.array(distance)
382
- row_ind, col_ind = optimize.linear_sum_assignment(cost_matrix)
383
- permutation = [idx for _, idx in sorted(zip(col_ind, row_ind))]
384
- score = (1 - cost_matrix)[permutation[1:], range(1, len(row_ind))].prod()
385
- return self.permuted(permutation), score
386
-
387
- def _parse_table(text, transposed = False): # ํ‘œ ์ œ๋ชฉ, ์—ด ์ด๋ฆ„, ํ–‰ ์ฐพ๊ธฐ
388
- """Builds a table from a markdown representation."""
389
- lines = text.lower().splitlines()
390
- if not lines:
391
- return Table()
392
-
393
- if lines[0].startswith("๋Œ€์ƒ: "):
394
- base = lines[0][len("๋Œ€์ƒ: ") :].strip()
395
- offset = 1 #
396
- else:
397
- base = None
398
- offset = 0
399
- if lines[1].startswith("์ œ๋ชฉ: "):
400
- title = lines[1][len("์ œ๋ชฉ: ") :].strip()
401
- offset = 2 #
402
- else:
403
- title = None
404
- offset = 1
405
- if lines[2].startswith("์œ ํ˜•: "):
406
- chartType = lines[2][len("์œ ํ˜•: ") :].strip()
407
- offset = 3 #
408
- else:
409
- chartType = None
410
-
411
- if len(lines) < offset + 1:
412
- return Table(base=base, title=title, chartType=chartType)
413
-
414
- rows = []
415
- for line in lines[offset:]:
416
- rows.append(tuple(v.strip() for v in line.split(" | ")))
417
- if transposed:
418
- rows = [tuple(row) for row in itertools.zip_longest(*rows, fillvalue="")]
419
- return Table(base=base, title=title, chartType=chartType, headers=rows[0], rows=tuple(rows[1:]))
420
-
421
- def _get_table_datapoints(table):
422
- datapoints = {}
423
- if table.base is not None:
424
- datapoints["๋Œ€์ƒ"] = table.base
425
- if table.title is not None:
426
- datapoints["์ œ๋ชฉ"] = table.title
427
- if table.chartType is not None:
428
- datapoints["์œ ํ˜•"] = table.chartType
429
- if not table.rows or len(table.headers) <= 1:
430
- return datapoints
431
- for row in table.rows:
432
- for header, cell in zip(table.headers[1:], row[1:]):
433
- #print(f"{row[0]} {header} >> {cell}")
434
- datapoints[f"{row[0]} {header}"] = cell #
435
- return datapoints
436
-
437
- def _get_datapoint_metric( #
438
- target,
439
- prediction,
440
- text_theta=0.5,
441
- number_theta=0.1,
442
- ):
443
- """Computes a metric that scores how similar two datapoint pairs are."""
444
- key_metric = anls_metric(
445
- target[0], prediction[0], text_theta
446
- )
447
- pred_float = _to_float(prediction[1]) # ์ˆซ์ž์ธ์ง€ ํ™•์ธ
448
- target_float = _to_float(target[1])
449
- if pred_float is not None and target_float:
450
- return key_metric * (
451
- 1 - _get_relative_distance(target_float, pred_float, number_theta) # ์ˆซ์ž๋ฉด ์ƒ๋Œ€์  ๊ฑฐ๋ฆฌ๊ฐ’ ๊ณ„์‚ฐ
452
- )
453
- elif target[1] == prediction[1]:
454
- return key_metric
455
- else:
456
- return key_metric * anls_metric(
457
- target[1], prediction[1], text_theta
458
- )
459
-
460
- def _table_datapoints_precision_recall_f1( # ์ฐ ๊ณ„์‚ฐ
461
- target_table,
462
- prediction_table,
463
- text_theta = 0.5,
464
- number_theta = 0.1,
465
- ):
466
- """Calculates matching similarity between two tables as dicts."""
467
- target_datapoints = list(_get_table_datapoints(target_table).items())
468
- prediction_datapoints = list(_get_table_datapoints(prediction_table).items())
469
- if not target_datapoints and not prediction_datapoints:
470
- return 1, 1, 1
471
- if not target_datapoints:
472
- return 0, 1, 0
473
- if not prediction_datapoints:
474
- return 1, 0, 0
475
- distance = []
476
- for t, _ in target_datapoints:
477
- distance.append(
478
- [
479
- 1 - anls_metric(t, p, text_theta)
480
- for p, _ in prediction_datapoints
481
  ]
482
  )
483
- cost_matrix = np.array(distance)
484
- row_ind, col_ind = optimize.linear_sum_assignment(cost_matrix)
485
- score = 0
486
- for r, c in zip(row_ind, col_ind):
487
- score += _get_datapoint_metric(
488
- target_datapoints[r], prediction_datapoints[c], text_theta, number_theta
489
- )
490
- if score == 0:
491
- return 0, 0, 0
492
- precision = score / len(prediction_datapoints)
493
- recall = score / len(target_datapoints)
494
- return precision, recall, 2 * precision * recall / (precision + recall)
495
-
496
- def table_datapoints_precision_recall_per_point( # ๊ฐ๊ฐ ๊ณ„์‚ฐ...
497
- targets,
498
- predictions,
499
- text_theta = 0.5,
500
- number_theta = 0.1,
501
- ):
502
- """Computes precisin recall and F1 metrics given two flattened tables.
503
- Parses each string into a dictionary of keys and values using row and column
504
- headers. Then we match keys between the two dicts as long as their relative
505
- levenshtein distance is below a threshold. Values are also compared with
506
- ANLS if strings or relative distance if they are numeric.
507
- Args:
508
- targets: list of list of strings.
509
- predictions: list of strings.
510
- text_theta: relative edit distance above this is set to the maximum of 1.
511
- number_theta: relative error rate above this is set to the maximum of 1.
512
- Returns:
513
- Dictionary with per-point precision, recall and F1
514
- """
515
- assert len(targets) == len(predictions)
516
- per_point_scores = {"precision": [], "recall": [], "f1": []}
517
- for pred, target in zip(predictions, targets):
518
- all_metrics = []
519
- for transposed in [True, False]:
520
- pred_table = _parse_table(pred, transposed=transposed)
521
- target_table = _parse_table(target, transposed=transposed)
522
-
523
- all_metrics.extend([_table_datapoints_precision_recall_f1(target_table, pred_table, text_theta, number_theta)])
524
 
525
- p, r, f = max(all_metrics, key=lambda x: x[-1])
526
- per_point_scores["precision"].append(p)
527
- per_point_scores["recall"].append(r)
528
- per_point_scores["f1"].append(f)
529
- return per_point_scores
530
 
531
- def table_datapoints_precision_recall( # deplot ์„ฑ๋Šฅ์ง€ํ‘œ
532
- targets,
533
- predictions,
534
- text_theta = 0.5,
535
- number_theta = 0.1,
536
- ):
537
- """Aggregated version of table_datapoints_precision_recall_per_point().
538
- Same as table_datapoints_precision_recall_per_point() but returning aggregated
539
- scores instead of per-point scores.
540
- Args:
541
- targets: list of list of strings.
542
- predictions: list of strings.
543
- text_theta: relative edit distance above this is set to the maximum of 1.
544
- number_theta: relative error rate above this is set to the maximum of 1.
545
- Returns:
546
- Dictionary with aggregated precision, recall and F1
547
- """
548
- score_dict = table_datapoints_precision_recall_per_point(
549
- targets, predictions, text_theta, number_theta
550
- )
551
- return {
552
- "table_datapoints_precision": (
553
- sum(score_dict["precision"]) / len(targets)
554
- ),
555
- "table_datapoints_recall": (
556
- sum(score_dict["recall"]) / len(targets)
557
- ),
558
- "table_datapoints_f1": sum(score_dict["f1"]) / len(targets),
559
- }
560
-
561
- def evaluate_rms(generated_table,label_table):
562
- predictions=[generated_table]
563
- targets=[label_table]
564
- RMS = table_datapoints_precision_recall(targets, predictions)
565
- return RMS
566
-
567
- def ko_deplot_convert_to_dataframe(generated_table_str):
568
- lines = generated_table_str.strip().split(" \n")
569
- headers=[]
570
- data=[]
571
- for i in range(len(lines[1].split(" | "))):
572
- headers.append(f"{i}")
573
- for line in lines[1:len(lines)-1]:
574
- data.append(line.split("| "))
575
- df = pd.DataFrame(data, columns=headers)
576
- return df
577
-
578
- def ko_deplot_convert_to_dataframe2(label_table_str):
579
- lines = label_table_str.strip().split(" \n")
580
- headers=[]
581
  data=[]
582
- for i in range(len(lines[1].split(" | "))):
583
- headers.append(f"{i}")
584
- for line in lines[1:]:
585
- data.append(line.split("| "))
586
- df = pd.DataFrame(data, columns=headers)
587
- return df
588
-
589
- def aihub_deplot_convert_to_dataframe(table_str):
590
- lines = table_str.strip().split("\n")
591
- headers = []
592
- if(len(lines[3].split(" | "))>len(lines[4].split(" | "))):
593
- category=lines[3].split(" | ")
594
- del category[0]
595
- value=lines[4].split(" | ")
596
- df=pd.DataFrame({"๋ฒ”๋ก€":category,"๊ฐ’":value})
597
- return df
598
- else:
599
- for i in range(len(lines[3].split(" | "))):
600
- headers.append(f"{i}")
601
- data = [line.split(" | ") for line in lines[3:]]
602
- df = pd.DataFrame(data, columns=headers)
603
- return df
604
- def unichart_convert_to_dataframe(table_str):
605
- lines=table_str.split(" & ")
606
- headers=[]
607
- data=[]
608
- del lines[0]
609
- for i in range(len(lines[1].split(" | "))):
610
- headers.append(f"{i}")
611
- if lines[0]=="value":
612
- for line in lines[1:]:
613
- data.append(line.split(" | "))
614
- else:
615
- category=lines[0].split(" | ")
616
- category.insert(0," ")
617
- data.append(category)
618
- for line in lines[1:]:
619
- data.append(line.split(" | "))
620
- df=pd.DataFrame(data,columns=headers)
621
- return df
622
-
623
- class Highlighter:
624
- def __init__(self):
625
- self.row = 0
626
- self.col = 0
627
-
628
- def compare_and_highlight(self, pred_table_elem, target_table, pred_table_row, props=''):
629
- if self.row >= pred_table_row:
630
- self.col += 1
631
- self.row = 0
632
- if pred_table_elem != target_table.iloc[self.row, self.col]:
633
- self.row += 1
634
- return props
635
- else:
636
- self.row += 1
637
- return None
638
-
639
- # 1. ๋ฐ์ดํ„ฐ ๋กœ๋“œ
640
- aihub_deplot_result_df = pd.read_csv('./aihub_deplot_result.csv')
641
- ko_deplot_result= './ko-deplot-base-pred-epoch3-refinetuning.json'
642
- unichart_result='./unichart_results.json'
643
-
644
- # 2. ์ฒดํฌํ•ด์•ผ ํ•˜๋Š” ์ด๋ฏธ์ง€ ํŒŒ์ผ ๋กœ๋“œ
645
- def load_image_checklist(file):
646
- with open(file, 'r') as f:
647
- #image_names = [f'"{line.strip()}"' for line in f]
648
- image_names = f.read().splitlines()
649
- return image_names
650
-
651
- # 3. ํ˜„์žฌ ์ธ๋ฑ์Šค๋ฅผ ์ถ”์ ํ•˜๊ธฐ ์œ„ํ•œ ๋ณ€์ˆ˜
652
- current_index = 0
653
- image_names = []
654
- def show_image(current_idx):
655
- image_name=image_names[current_idx]
656
- image_path = f"./top_20_percent_images/{image_name}.jpg"
657
- if not os.path.exists(image_path):
658
- image_path = f"./bottom_20_percent_images/{image_name}.jpg"
659
- return Image.open(image_path)
660
-
661
- # 4. ๋ฒ„ํŠผ ํด๋ฆญ ์ด๋ฒคํŠธ ํ•ธ๋“ค๋Ÿฌ
662
- def non_real_time_check(file):
663
- highlighter1 = Highlighter()
664
- highlighter2 = Highlighter()
665
- highlighter3 = Highlighter()
666
- #global image_names, current_index
667
- #image_names = load_image_checklist(file)
668
- #current_index = 0
669
- #image=show_image(current_index)
670
- file_name =image_names[current_index].replace("Source","Label")
671
-
672
- json_path="./ko_deplot_labeling_data.json"
673
- with open(json_path, 'r', encoding='utf-8') as file:
674
- json_data = json.load(file)
675
- for key, value in json_data.items():
676
- if key == file_name:
677
- ko_deplot_labeling_str=value.get("txt").replace("<0x0A>","\n")
678
- ko_deplot_label_title=ko_deplot_labeling_str.split(" \n ")[0].replace("TITLE | ","์ œ๋ชฉ:")
679
- break
680
-
681
- ko_deplot_rms_path="./ko_deplot_rms.txt"
682
- unichart_rms_path="./unichart_rms.txt"
683
-
684
- json_path="./unichart_labeling_data.json"
685
- with open(json_path, 'r', encoding='utf-8') as file:
686
- json_data = json.load(file)
687
- for entry in json_data:
688
- if entry["imgname"]==image_names[current_index]+".jpg":
689
- unichart_labeling_str=entry["label"]
690
- unichart_label_title=entry["label"].split(" & ")[0].split(" | ")[1]
691
-
692
- with open(ko_deplot_rms_path,'r',encoding='utf-8') as file:
693
- lines=file.readlines()
694
- flag=0
695
- for line in lines:
696
- parts=line.strip().split(", ")
697
- if(len(parts)==2 and parts[0]==image_names[current_index]):
698
- ko_deplot_rms=parts[1]
699
- flag=1
700
- break
701
- if(flag==0):
702
- ko_deplot_rms="none"
703
-
704
- with open(unichart_rms_path,'r',encoding='utf-8') as file:
705
- lines=file.readlines()
706
- flag=0
707
- for line in lines:
708
- parts=line.strip().split(": ")
709
- if(len(parts)==2 and parts[0]==image_names[current_index]+".jpg"):
710
- unichart_rms=parts[1]
711
- flag=1
712
- break
713
- if(flag==0):
714
- unichart_rms="none"
715
-
716
-
717
-
718
- ko_deplot_generated_title,ko_deplot_generated_table=ko_deplot_display_results(current_index)
719
- aihub_deplot_generated_table,aihub_deplot_label_table,aihub_deplot_generated_title,aihub_deplot_label_title=aihub_deplot_display_results(current_index)
720
- unichart_generated_table,unichart_generated_title=unichart_display_results(current_index)
721
- #ko_deplot_RMS=evaluate_rms(ko_deplot_generated_table,ko_deplot_labeling_str)
722
- aihub_deplot_RMS=evaluate_rms(aihub_deplot_generated_table,aihub_deplot_label_table)
723
-
724
-
725
- if flag == 1:
726
- value = [round(float(ko_deplot_rms), 1)]
727
  else:
728
- value = [0]
729
-
730
- ko_deplot_score_table = pd.DataFrame({
731
- 'category': ['f1'],
732
- 'value': value
733
- })
734
-
735
- value=[round(float(unichart_rms)/100,1)]
736
- unichart_score_table=pd.DataFrame({
737
- 'category':['f1'],
738
- 'value':value
739
- })
740
- aihub_deplot_score_table=pd.DataFrame({
741
- 'category': ['precision', 'recall', 'f1'],
742
- 'value': [
743
- round(aihub_deplot_RMS['table_datapoints_precision'],1),
744
- round(aihub_deplot_RMS['table_datapoints_recall'],1),
745
- round(aihub_deplot_RMS['table_datapoints_f1'],1)
746
- ]
747
- })
748
-
749
- #ko_deplot_generated_df=ko_deplot_convert_to_dataframe(ko_deplot_generated_table)
750
- #aihub_deplot_generated_df=aihub_deplot_convert_to_dataframe(aihub_deplot_generated_table)
751
- #unichart_generated_df=unichart_convert_to_dataframe(unichart_generated_table)
752
-
753
  try:
754
- ko_deplot_generated_df=ko_deplot_convert_to_dataframe(ko_deplot_generated_table)
755
- unichart_generated_df=unichart_convert_to_dataframe(unichart_generated_table)
 
 
 
 
 
 
 
 
756
  except Exception as e:
757
- return None,None,None,None,None,None,None,None,None,ko_deplot_generated_table,unichart_generated_table,1
758
- ko_deplot_labeling_df=ko_deplot_convert_to_dataframe2(ko_deplot_labeling_str)
759
- #aihub_deplot_labeling_df=aihub_deplot_convert_to_dataframe(aihub_deplot_label_table)
760
- unichart_labeling_df=unichart_convert_to_dataframe(unichart_labeling_str)
761
-
762
- ko_deplot_generated_df_row=ko_deplot_generated_df.shape[0]
763
- #aihub_deplot_generated_df_row=aihub_deplot_generated_df.shape[0]
764
- unichart_generated_df_row=unichart_generated_df.shape[0]
765
-
766
-
767
- styled_ko_deplot_table=ko_deplot_generated_df.style.applymap(highlighter1.compare_and_highlight,target_table=ko_deplot_labeling_df,pred_table_row=ko_deplot_generated_df_row,props='color:red')
768
-
769
-
770
- #styled_aihub_deplot_table=aihub_deplot_generated_df.style.applymap(highlighter2.compare_and_highlight,target_table=aihub_deplot_labeling_df,pred_table_row=aihub_deplot_generated_df_row,props='color:red')
771
-
772
-
773
- styled_unichart_table=unichart_generated_df.style.applymap(highlighter3.compare_and_highlight,target_table=unichart_labeling_df,pred_table_row=unichart_generated_df_row,props='color:red')
774
-
775
- #return ko_deplot_convert_to_dataframe(ko_deplot_generated_table), aihub_deplot_convert_to_dataframe(aihub_deplot_generated_table), aihub_deplot_convert_to_dataframe(label_table), ko_deplot_score_table, aihub_deplot_score_table
776
- return gr.DataFrame(styled_ko_deplot_table,label=ko_deplot_generated_title+"(VAIV_DePlot ์ถ”๋ก  ๊ฒฐ๊ณผ)"),None,gr.DataFrame(styled_unichart_table,label="์ œ๋ชฉ:"+unichart_generated_title+"(VAIV_UniChart ์ถ”๋ก  ๊ฒฐ๊ณผ)"),gr.DataFrame(ko_deplot_labeling_df,label=ko_deplot_label_title+"(VAIV_DePlot ์ •๋‹ต ํ…Œ์ด๋ธ”)"),None,gr.DataFrame(unichart_labeling_df,label="์ œ๋ชฉ:"+unichart_label_title+"(VAIV_UniChart ์ •๋‹ต ํ…Œ์ด๋ธ”)"),ko_deplot_score_table, aihub_deplot_score_table,unichart_score_table,None,None,0
777
-
778
-
779
- def ko_deplot_display_results(index):
780
- filename=image_names[index]+".jpg"
781
- with open(ko_deplot_result, 'r', encoding='utf-8') as f:
782
- data = json.load(f)
783
- for entry in data:
784
- if entry['filename'].endswith(filename):
785
- #return entry['table']
786
- parts=entry['table'].split("\n",1)
787
- return parts[0].replace("TITLE | ","์ œ๋ชฉ:"),entry['table']
788
-
789
- def aihub_deplot_display_results(index):
790
- if index < 0 or index >= len(image_names):
791
- return "Index out of range", None, None
792
- image_name = image_names[index]
793
- image_row = aihub_deplot_result_df[aihub_deplot_result_df['data_id'] == image_name]
794
- if not image_row.empty:
795
- generated_table = image_row['generated_table'].values[0]
796
- generated_title=generated_table.split("\n")[1]
797
- label_table = image_row['label_table'].values[0]
798
- label_title=label_table.split("\n")[1]
799
- return generated_table, label_table, generated_title, label_title
800
- else:
801
- return "No results found for the image", None, None
802
- def unichart_display_results(index):
803
- image_name=image_names[index]
804
- with open(unichart_result,'r',encoding='utf-8') as f:
805
- data=json.load(f)
806
- for entry in data:
807
- if entry['imgname']==image_name+".jpg":
808
- return entry['label'],entry['label'].split(" & ")[0].split(" | ")[1]
809
-
810
- def previous_image():
811
- global current_index
812
- if current_index>0:
813
- current_index-=1
814
- image=show_image(current_index)
815
- return image, image_names[current_index],gr.update(interactive=current_index>0), gr.update(interactive=current_index<len(image_names)-1)
816
-
817
- def next_image():
818
- global current_index
819
- if current_index<len(image_names)-1:
820
- current_index+=1
821
- image=show_image(current_index)
822
- return image, image_names[current_index],gr.update(interactive=current_index>0), gr.update(interactive=current_index<len(image_names)-1)
823
 
824
  def real_time_check(image_file):
825
- highlighter1 = Highlighter()
826
- highlighter2 = Highlighter()
827
- highlighter3=Highlighter()
828
  image = Image.open(image_file)
829
-
830
- result_model1 = predict_model1(image)
831
- parts=result_model1.split("\n")
832
  del parts[-1]
833
- result_model1="\n".join(parts)
834
- ko_deplot_generated_title=result_model1.split("\n")[0].split(" | ")[1]
835
- #ko_deplot_table=ko_deplot_convert_to_dataframe2(result_model1)
836
-
837
- result_model3=predict_model3(image)
838
- #unichart_table=unichart_convert_to_dataframe(result_model3)
839
- unichart_generated_title=result_model3.split(" & ")[0].split(" | ")[1]
840
-
841
  try:
842
- ko_deplot_table=ko_deplot_convert_to_dataframe2(result_model1)
843
- unichart_table=unichart_convert_to_dataframe(result_model3)
 
844
  except Exception as e:
845
- return None,None,None,None,None,None,None,None,None,result_model1,result_model3,1
846
-
847
- #aihub_labeling_data_json="./labeling_data/"+file_name+".json"
848
- if os.path.basename(image_file.name).startswith("C_Source"):
849
- image_base_name = os.path.basename(image_file.name).replace("Source","Label")
850
- file_name, _ = os.path.splitext(image_base_name)
851
- json_path="./ko_deplot_labeling_data.json"
852
- with open(json_path, 'r', encoding='utf-8') as file:
853
- json_data = json.load(file)
854
- for key, value in json_data.items():
855
- if key == file_name:
856
- ko_deplot_labeling_str=value.get("txt").replace("<0x0A>","\n")
857
- ko_deplot_label_title=ko_deplot_labeling_str.split(" \n ")[0].split(" | ")[1]
858
- break
859
-
860
- ko_deplot_label_table=ko_deplot_convert_to_dataframe2(ko_deplot_labeling_str)
861
-
862
- #aihub_deplot_labeling_str=process_json_file2(aihub_labeling_data_json)
863
- #aihub_deplot_label_title=aihub_deplot_labeling_str.split("\n")[1].split(":")[1]
864
-
865
- json_path="./unichart_labeling_data.json"
866
- with open(json_path, 'r', encoding='utf-8') as file:
867
- json_data = json.load(file)
868
- for entry in json_data:
869
- if entry["imgname"]==os.path.basename(image_file.name):
870
- unichart_labeling_str=entry["label"]
871
- unichart_label_title=entry["label"].split(" & ")[0].split(" | ")[1]
872
- unichart_label_table=unichart_convert_to_dataframe(unichart_labeling_str)
873
-
874
- ko_deplot_RMS=evaluate_rms(result_model1,ko_deplot_labeling_str)
875
- unichart_RMS=evaluate_rms(result_model3.replace("Characteristic","Title").replace("&","\n"),unichart_labeling_str.replace("Characteristic","Title").replace("&","\n"))
876
- ko_deplot_score_table=pd.DataFrame({
877
- 'category': ['precision', 'recall', 'f1'],
878
- 'value': [
879
- round(ko_deplot_RMS['table_datapoints_precision'],1),
880
- round(ko_deplot_RMS['table_datapoints_recall'],1),
881
- round(ko_deplot_RMS['table_datapoints_f1'],1)
882
- ]
883
- })
884
- unichart_score_table=pd.DataFrame({
885
- 'category': ['precision', 'recall', 'f1'],
886
- 'value': [
887
- round(unichart_RMS['table_datapoints_precision'],1),
888
- round(unichart_RMS['table_datapoints_recall'],1),
889
- round(unichart_RMS['table_datapoints_f1'],1)
890
- ]
891
- })
892
 
893
- ko_deplot_generated_df_row=ko_deplot_table.shape[0]
894
- unichart_generated_df_row=unichart_table.shape[0]
895
- styled_ko_deplot_table=ko_deplot_table.style.applymap(highlighter1.compare_and_highlight,target_table=ko_deplot_label_table,pred_table_row=ko_deplot_generated_df_row,props='color:red')
896
- styled_unichart_table=unichart_table.style.applymap(highlighter3.compare_and_highlight,target_table=unichart_label_table,pred_table_row=unichart_generated_df_row,props='color:red')
897
- return gr.DataFrame(styled_ko_deplot_table,label=ko_deplot_generated_title+"(VAIV_DePlot ์ถ”๋ก  ๊ฒฐ๊ณผ)") ,None,gr.DataFrame(styled_unichart_table,label=unichart_generated_title+"(VAIV_UniChart ์ถ”๋ก  ๊ฒฐ๊ณผ)"),gr.DataFrame(ko_deplot_label_table,label=ko_deplot_label_title+"(VAIV_DePlot ์ •๋‹ต ํ…Œ์ด๋ธ”)"),None,gr.DataFrame(unichart_label_table,label=unichart_label_title+"(VAIV_UniChart ์ •๋‹ต ํ…Œ์ด๋ธ”)"),ko_deplot_score_table,None,unichart_score_table,None,None,0
898
- else:
899
- return gr.DataFrame(ko_deplot_table,label=ko_deplot_generated_title+"(VAIV_DePlot ์ถ”๋ก  ๊ฒฐ๊ณผ)"),None,gr.DataFrame(unichart_table,label=unichart_generated_title+"(VAIV_UniChart ์ถ”๋ก  ๊ฒฐ๊ณผ)"),None,None,None,None,None,None,None,None,0
900
- def inference(mode,image_uploader,file_uploader):
901
- if(mode=="์ด๋ฏธ์ง€ ์—…๋กœ๋“œ"):
902
- ko_deplot_table, aihub_deplot_table, unichart_table, ko_deplot_label_table,aihub_deplot_label_table,unichart_label_table,ko_deplot_score_table, aihub_deplot_score_table,unichart_score_table,ko_deplot_generated_txt,unichart_generated_txt,flag= real_time_check(image_uploader)
903
  if flag==1:
904
- return ko_deplot_table, aihub_deplot_table, unichart_table,ko_deplot_label_table, aihub_deplot_label_table,unichart_label_table,ko_deplot_score_table, aihub_deplot_score_table,unichart_score_table,gr.Text(ko_deplot_generated_txt,visible=True),gr.Text(unichart_generated_txt,visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False)
905
  else:
906
- return ko_deplot_table, aihub_deplot_table, unichart_table,ko_deplot_label_table, aihub_deplot_label_table,unichart_label_table,ko_deplot_score_table, aihub_deplot_score_table,unichart_score_table,gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False)
907
  else:
908
- styled_ko_deplot_table,styled_aihub_deplot_table,styled_unichart_table,ko_deplot_label_table,aihub_deplot_label_table,unichart_label_table,ko_deplot_score_table,aihub_deplot_score_table, unichart_score_table,ko_deplot_generated_txt,unichart_generated_txt,flag=non_real_time_check(file_uploader)
909
  if flag==1:
910
- return styled_ko_deplot_table, styled_aihub_deplot_table, styled_unichart_table,ko_deplot_label_table,aihub_deplot_label_table,unichart_label_table,ko_deplot_score_table, aihub_deplot_score_table, unichart_score_table,gr.Text(ko_deplot_generated_txt,visible=True),gr.Text(unichart_generated_txt,visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False)
911
  else:
912
- return styled_ko_deplot_table, styled_aihub_deplot_table, styled_unichart_table,ko_deplot_label_table,aihub_deplot_label_table,unichart_label_table,ko_deplot_score_table, aihub_deplot_score_table, unichart_score_table,gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False)
913
- def interface_selector(selector):
914
- if selector == "์ด๋ฏธ์ง€ ์—…๋กœ๋“œ":
915
- return gr.update(visible=True),gr.update(visible=False),gr.State("image_upload"),gr.update(visible=False),gr.update(visible=False),gr.File("./new_top_20_percent_images.txt"),"high score ์ฐจํŠธ"
916
- elif selector == "ํŒŒ์ผ ์—…๋กœ๋“œ":
917
- return gr.update(visible=False),gr.update(visible=True),gr.State("file_upload"), gr.update(visible=True),gr.update(visible=True),gr.File("./new_top_20_percent_images.txt"),"high score ์ฐจํŠธ"
918
-
919
- def file_selector(selector):
920
- if selector == "low score ์ฐจํŠธ":
921
- return gr.File("./new_bottom_20_percent_images.txt"),"์ „์ฒด"
922
- elif selector == "high score ์ฐจํŠธ":
923
- return gr.File("./new_top_20_percent_images.txt"),"์ „์ฒด"
924
- '''
925
- def update_results(model_type):
926
- if "ko_deplot" == model_type:
927
- return gr.update(visible=True),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False)
928
- elif "aihub_deplot" == model_type:
929
- return gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=False)
930
- elif "unichart"==model_type:
931
- return gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=True)
932
- else:
933
- return gr.update(visible=True), gr.update(visible=True),gr.update(visible=True),gr.update(visible=True),gr.update(visible=True),gr.update(visible=True),gr.update(visible=True),gr.update(visible=True),gr.update(visible=True)
934
- '''
935
-
936
- def update_results(selected_models):
937
  # Create a visibility list initialized to False for all components
938
- visibility = [False] * 9
939
-
940
  # Update visibility based on the selected models
941
  if "VAIV_DePlot" in selected_models:
942
- visibility[0] = True # ko_deplot_generated_table
943
- visibility[1] = True # ko_deplot_score_table
944
- visibility[6] = True # ko_deplot_label_table
945
- '''
946
- if "aihub_deplot" in selected_models:
947
- visibility[2] = True # aihub_deplot_generated_table
948
- visibility[3] = True # aihub_deplot_score_table
949
- visibility[7] = True # aihub_deplot_label_table
950
- '''
951
- if "VAIV_UniChart" in selected_models:
952
- visibility[4] = True # unichart_generated_table
953
- visibility[5] = True # unichart_score_table
954
- visibility[8] = True # unichart_label_table
955
-
956
  if "all" in selected_models:
957
- visibility[0] = True # ko_deplot_generated_table
958
- visibility[1] = True # ko_deplot_score_table
959
- visibility[6] = True # ko_deplot_label_table
960
- visibility[4] = True # unichart_generated_table
961
- visibility[5] = True # unichart_score_table
962
- visibility[8] = True # unichart_label_table
963
-
 
964
  # Return gr.update for each component with the corresponding visibility status
965
  return tuple(gr.update(visible=v) for v in visibility)
966
 
 
 
 
 
 
967
 
968
  def display_image(image_file):
969
  image=Image.open(image_file)
970
  return image, os.path.basename(image_file)
971
 
972
- def display_image_in_file(image_checklist):
973
- global image_names, current_index
974
- image_names = load_image_checklist(image_checklist)
975
- image=show_image(current_index)
976
- return image,image_names[current_index]
977
-
978
- def update_file_based_on_chart_type(chart_type, all_file_path):
979
- with open(all_file_path, 'r', encoding='utf-8') as file:
980
- lines = file.readlines()
981
- filtered_lines=[]
982
- if chart_type == "์ „์ฒด":
983
- filtered_lines = lines
984
- elif chart_type == "์ผ๋ฐ˜ ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•":
985
- filtered_lines = [line for line in lines if "_horizontal bar_standard" in line]
986
- elif chart_type=="๋ˆ„์  ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•":
987
- filtered_lines = [line for line in lines if "_horizontal bar_accumulation" in line]
988
- elif chart_type=="100% ๊ธฐ์ค€ ๋ˆ„์  ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•":
989
- filtered_lines = [line for line in lines if "_horizontal bar_100per accumulation" in line]
990
- elif chart_type=="์ผ๋ฐ˜ ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•":
991
- filtered_lines = [line for line in lines if "_vertical bar_standard" in line]
992
- elif chart_type=="๋ˆ„์  ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•":
993
- filtered_lines = [line for line in lines if "_vertical bar_accumulation" in line]
994
- elif chart_type=="100% ๊ธฐ์ค€ ๋ˆ„์  ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•":
995
- filtered_lines = [line for line in lines if "_vertical bar_100per accumulation" in line]
996
- elif chart_type=="์„ ํ˜•":
997
- filtered_lines = [line for line in lines if "_line_standard" in line]
998
- elif chart_type=="์›ํ˜•":
999
- filtered_lines = [line for line in lines if "_pie_standard" in line]
1000
- elif chart_type=="๊ธฐํƒ€ ๋ฐฉ์‚ฌํ˜•":
1001
- filtered_lines = [line for line in lines if "_etc_radial" in line]
1002
- elif chart_type=="๊ธฐํƒ€ ํ˜ผํ•ฉํ˜•":
1003
- filtered_lines = [line for line in lines if "_etc_mix" in line]
1004
- # ์ƒˆ๋กœ์šด ํŒŒ์ผ์— ๊ธฐ๋ก
1005
- new_file_path = "./filtered_chart_images.txt"
1006
- with open(new_file_path, 'w', encoding='utf-8') as file:
1007
- file.writelines(filtered_lines)
1008
-
1009
- return new_file_path
1010
 
1011
- def handle_chart_type_change(chart_type,all_file_path):
1012
- new_file_path = update_file_based_on_chart_type(chart_type, all_file_path)
1013
- global image_names, current_index
1014
- image_names = load_image_checklist(new_file_path)
1015
- current_index=0
1016
- image=show_image(current_index)
1017
- return image,image_names[current_index]
 
 
 
 
 
 
 
 
 
 
1018
 
1019
  css = """
1020
  .dataframe-class {
1021
- height: 300px; /* ๋†’์ด๋ฅผ ๊ณ ์ • */
1022
  overflow-y: auto !important; /* ์Šคํฌ๋กค์„ ๊ฐ€๋Šฅํ•˜๊ฒŒ */
 
1023
  }
1024
  """
1025
 
1026
  with gr.Blocks(css=css) as iface:
1027
- mode=gr.State("image_upload")
 
 
1028
  with gr.Row():
1029
  with gr.Column():
1030
- #mode_label=gr.Text("์ด๋ฏธ์ง€ ์—…๋กœ๋“œ๊ฐ€ ์„ ํƒ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.")
1031
- upload_option = gr.Radio(choices=["์ด๋ฏธ์ง€ ์—…๋กœ๋“œ", "ํŒŒ์ผ ์—…๋กœ๋“œ"], value="์ด๋ฏธ์ง€ ์—…๋กœ๋“œ", label="์—…๋กœ๋“œ ์˜ต์…˜")
1032
- #with gr.Row():
1033
- #image_button = gr.Button("์ด๋ฏธ์ง€ ์—…๋กœ๋“œ")
1034
- #file_button = gr.Button("ํŒŒ์ผ ์—…๋กœ๋“œ")
1035
-
1036
- # ์ด๋ฏธ์ง€์™€ ํŒŒ์ผ ์—…๋กœ๋“œ ์ปดํฌ๋„ŒํŠธ (์ดˆ๊ธฐ์—๋Š” ์ˆจ๊น€ ์ƒํƒœ)
1037
- # global image_uploader,file_uploader
1038
- image_uploader= gr.File(file_count="single",file_types=["image"],visible=True)
1039
- file_uploader= gr.File(file_count="single", file_types=[".txt"], visible=False)
1040
- file_upload_option=gr.Radio(choices=["low score ์ฐจํŠธ","high score ์ฐจํŠธ"],label="ํŒŒ์ผ ์—…๋กœ๋“œ ์˜ต์…˜",visible=False)
1041
- chart_type = gr.Dropdown(["์ผ๋ฐ˜ ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•","๋ˆ„์  ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•","100% ๊ธฐ์ค€ ๋ˆ„์  ๊ฐ€๋กœ ๋ง‰๋Œ€ํ˜•", "์ผ๋ฐ˜ ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•","๋ˆ„์  ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•","100% ๊ธฐ์ค€ ๋ˆ„์  ์„ธ๋กœ ๋ง‰๋Œ€ํ˜•","์„ ํ˜•", "์›ํ˜•", "๊ธฐํƒ€ ๋ฐฉ์‚ฌํ˜•", "๊ธฐํƒ€ ํ˜ผํ•ฉํ˜•", "์ „์ฒด"], label="Chart Type", value="all")
1042
- model_type=gr.Dropdown(["VAIV_DePlot","VAIV_UniChart","all"],value="VAIV_DePlot",label="model",multiselect=True)
1043
- image_displayer=gr.Image(visible=True)
1044
  with gr.Row():
1045
- pre_button=gr.Button("์ด์ „",interactive="False")
1046
- next_button=gr.Button("๋‹ค์Œ")
1047
- image_name=gr.Text("์ด๋ฏธ์ง€ ์ด๋ฆ„",visible=False)
1048
- #image_button.click(interface_selector, inputs=gr.State("์ด๋ฏธ์ง€ ์—…๋กœ๋“œ"), outputs=[image_uploader,file_uploader,mode,mode_label,image_name])
1049
- #file_button.click(interface_selector, inputs=gr.State("ํŒŒ์ผ ์—…๋กœ๋“œ"), outputs=[image_uploader, file_uploader,mode,mode_label,image_name])
1050
- inference_button=gr.Button("์ถ”๋ก ")
1051
- with gr.Column():
1052
- ko_deplot_generated_table=gr.DataFrame(visible=True,label="VAIV_DePlot ์ถ”๋ก  ๊ฒฐ๊ณผ",elem_classes="dataframe-class")
1053
- aihub_deplot_generated_table=gr.DataFrame(visible=False,label="aihub-deplot ์ถ”๋ก  ๊ฒฐ๊ณผ",elem_classes="dataframe-class")
1054
- unichart_generated_table=gr.DataFrame(visible=False,label="VAIV_UniChart ์ถ”๋ก  ๊ฒฐ๊ณผ",elem_classes="dataframe-class")
1055
- ko_deplot_generated_txt=gr.Text(visible=False,label="VAIV_DePlot ์ถ”๋ก  ๊ฒฐ๊ณผ")
1056
- unichart_generated_txt=gr.Text(visible=False,label="VAIV_UniChart ์ถ”๋ก  ๊ฒฐ๊ณผ")
1057
- with gr.Column():
1058
- ko_deplot_label_table=gr.DataFrame(visible=True,label="VAIV_DePlot ์ •๋‹ตํ…Œ์ด๋ธ”",elem_classes="dataframe-class")
1059
- aihub_deplot_label_table=gr.DataFrame(visible=False,label="aihub-deplot ์ •๋‹ตํ…Œ์ด๋ธ”",elem_classes="dataframe-class")
1060
- unichart_label_table=gr.DataFrame(visible=False,label="VAIV_UniChart ์ •๋‹ตํ…Œ์ด๋ธ”",elem_classes="dataframe-class")
1061
  with gr.Column():
1062
- ko_deplot_score_table=gr.DataFrame(visible=True,label="VAIV_DePlot ์ ์ˆ˜",elem_classes="dataframe-class")
1063
- aihub_deplot_score_table=gr.DataFrame(visible=False,label="aihub_deplot ์ ์ˆ˜",elem_classes="dataframe-class")
1064
- unichart_score_table=gr.DataFrame(visible=False,label="VAIV_UniChart ์ ์ˆ˜",elem_classes="dataframe-class")
1065
- model_type.change(
1066
- update_results,
1067
- inputs=[model_type],
1068
- outputs=[ko_deplot_generated_table,ko_deplot_score_table,aihub_deplot_generated_table,aihub_deplot_score_table,unichart_generated_table,unichart_score_table,ko_deplot_label_table,aihub_deplot_label_table,unichart_label_table]
1069
- )
1070
-
1071
- upload_option.change(
1072
- interface_selector,
1073
- inputs=[upload_option],
1074
- outputs=[image_uploader, file_uploader, mode, image_name,file_upload_option,file_uploader,file_upload_option]
1075
- )
1076
 
1077
- file_upload_option.change(
1078
- file_selector,
1079
- inputs=[file_upload_option],
1080
- outputs=[file_uploader,chart_type]
 
 
 
 
 
 
1081
  )
1082
 
1083
- chart_type.change(handle_chart_type_change, inputs=[chart_type,file_uploader],outputs=[image_displayer,image_name])
1084
  image_uploader.upload(display_image,inputs=[image_uploader],outputs=[image_displayer,image_name])
1085
- file_uploader.change(display_image_in_file,inputs=[file_uploader],outputs=[image_displayer,image_name])
1086
- pre_button.click(previous_image, outputs=[image_displayer,image_name,pre_button,next_button])
1087
- next_button.click(next_image, outputs=[image_displayer,image_name,pre_button,next_button])
1088
- inference_button.click(inference,inputs=[upload_option,image_uploader,file_uploader],outputs=[ko_deplot_generated_table, aihub_deplot_generated_table, unichart_generated_table, ko_deplot_label_table, aihub_deplot_label_table, unichart_label_table, ko_deplot_score_table, aihub_deplot_score_table,unichart_score_table,ko_deplot_generated_txt,unichart_generated_txt,ko_deplot_generated_table, aihub_deplot_generated_table, unichart_generated_table, ko_deplot_label_table, aihub_deplot_label_table, unichart_label_table, ko_deplot_score_table, aihub_deplot_score_table,unichart_score_table])
1089
 
1090
- if __name__ == "__main__":
1091
- print("Launching Gradio interface...")
1092
- sys.stdout.flush() # stdout ๋ฒ„ํผ๋ฅผ ๋น„์›๋‹ˆ๋‹ค.
1093
- iface.launch(share=True)
1094
- #iface.launch(share=False,server_name="115.145.230.14",server_port=8080)
1095
- time.sleep(2) # Gradio URL์ด ์ถœ๋ ฅ๋  ๋•Œ๊นŒ์ง€ ์ž ์‹œ ๊ธฐ๋‹ค๋ฆฝ๋‹ˆ๋‹ค.
1096
- sys.stdout.flush() # ๋‹ค์‹œ stdout ๋ฒ„ํผ๋ฅผ ๋น„์›๋‹ˆ๋‹ค.
1097
- # Gradio๊ฐ€ ์ œ๊ณตํ•˜๋Š” URLs์„ ํŒŒ์ผ์— ๊ธฐ๋กํ•ฉ๋‹ˆ๋‹ค.
1098
- with open("gradio_url.log", "w") as f:
1099
- print(iface.local_url, file=f)
1100
- print(iface.share_url, file=f)
 
20
  import logging
21
  import subprocess
22
  import spaces
23
+ import openai
24
+ import base64
25
+ from io import StringIO
26
 
27
  # Git LFS pull ๋ช…๋ น์–ด ์‹คํ–‰
28
  result = subprocess.run(['git', 'lfs', 'pull'], capture_output=True, text=True)
 
39
  warnings.filterwarnings('ignore')
40
  MAX_PATCHES = 512
41
  # Load the models and processor
 
42
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
 
44
  # Paths to the models
45
+ ko_deplot_model_path = './deplot_model_ver_24.11.21_korean_only(exclude NUUA)_epoch1.bin'
 
 
46
 
47
  # Load first model ko-deplot
 
48
  def load_model1():
49
  processor1 = Pix2StructProcessor.from_pretrained('nuua/ko-deplot')
50
  model1 = Pix2StructForConditionalGeneration.from_pretrained('nuua/ko-deplot')
51
  model1.load_state_dict(torch.load(ko_deplot_model_path, map_location="cpu"))
52
  model1.to(torch.device("cuda"))
53
+ return processor1, model1
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
+ processor1, model1 = load_model1()
56
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # Function to format output
58
  def format_output(prediction):
59
  return prediction.replace('<0x0A>', '\n')
60
 
61
+ # First model prediction: ko-deplot
 
62
  def predict_model1(image):
63
  images = [image]
64
  inputs = processor1(images=images, text="What is the title of the chart", return_tensors="pt", padding=True)
 
72
  formatted_output = format_output(outputs[0])
73
  return formatted_output
74
 
75
+ # Set your OpenAI API key
76
+ openai.api_key = "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"
77
+
78
+ # Function to encode the image as base64
79
+ def encode_image(image_path):
80
+ with open(image_path, "rb") as image_file:
81
+ return base64.b64encode(image_file.read()).decode("utf-8")
82
+
83
+ # Second model prediction: gpt-4o-mini
84
+ def predict_model2(image):
85
+ # Encode the uploaded image to base64
86
+ image_data = encode_image(image)
87
+
88
+ # Prepare the request content
89
+ response = openai.ChatCompletion.create(
90
+ model="gpt-4o-mini",
91
+ messages=[
92
+ {
93
+ "role": "user",
94
+ "content": [
95
+ {
96
+ "type": "text",
97
+ "text": "please extract chart title and chart data manually and present them as a table. you should only provide title and table without adding any additional comments such as **Chart Title:** ."
98
+ },
99
+ {
100
+ "type": "image_url",
101
+ "image_url": {
102
+ "url": f"data:image/jpeg;base64,{image_data}"
103
+ }
104
+ }
105
+ ]
106
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  ]
108
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
+ # Return the table data from the response
111
+ return response.choices[0]["message"]["content"]
 
 
 
112
 
113
+ def ko_deplot_convert_to_dataframe(label_table_str): #function that converts text generated by ko-deplot to pandas dataframe
114
+ lines = label_table_str.strip().split("\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  data=[]
116
+ title= lines[0].split(" | ")[1]
117
+
118
+ if(len(lines[1].split("|")) == len(lines[2].split("|"))):
119
+ headers=lines[1].split(" | ")
120
+ for line in lines[2:]:
121
+ data.append(line.split(" | "))
122
+ df = pd.DataFrame(data, columns=headers)
123
+ return df, title
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  else:
125
+ legend_row=lines[1].split("|")
126
+ legend_row.insert(0," ")
127
+ for line in lines[2:]:
128
+ data.append(line.split(" | "))
129
+ df = pd.DataFrame(data, columns=legend_row)
130
+ return df, title
131
+
132
+ def gpt_convert_to_dataframe(table_text): #function that converts text generated by gpt to pandas dataframe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  try:
134
+ # Split the text into lines
135
+ lines = table_text.strip().split("\n")
136
+ title=lines[0]
137
+ lines.pop(1)
138
+ lines.pop(2)
139
+ # Process the remaining lines to create the DataFrame
140
+ data = [line.split("|")[1:-1] for line in lines[1:]] # Split by | and remove empty first/last items
141
+ dataframe = pd.DataFrame(data[1:], columns=[col.strip() for col in data[0]]) # Use the first row as headers
142
+
143
+ return dataframe, title
144
  except Exception as e:
145
+ return f"Error converting table to DataFrame: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  def real_time_check(image_file):
 
 
 
148
  image = Image.open(image_file)
149
+ ko_deplot_generated_txt = predict_model1(image)
150
+ parts=ko_deplot_generated_txt.split("\n")
 
151
  del parts[-1]
152
+ ko_deplot_generated_txt="\n".join(parts)
153
+ gpt_generated_txt=predict_model2(image_file)
 
 
 
 
 
 
154
  try:
155
+ ko_deplot_generated_df, ko_deplot_generated_title=ko_deplot_convert_to_dataframe(ko_deplot_generated_txt)
156
+ gpt_generated_df, gpt_generated_title=gpt_convert_to_dataframe(gpt_generated_txt)
157
+ return gr.DataFrame(ko_deplot_generated_df, label= ko_deplot_generated_title), gr.DataFrame(gpt_generated_df, label= gpt_generated_title), None,None,0
158
  except Exception as e:
159
+ return None,None,ko_deplot_generated_txt,gpt_generated_txt,1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
+ flag = 0 #flag to check whether exception happens or not. if flag is 1, it means that exception(generated txt cannot be converted to pandas dataframe) happens.
162
+ def inference(image_uploader,mode_selector):
163
+ if(mode_selector=="ํŒŒ์ผ ์—…๋กœ๋“œ"):
164
+ ko_deplot_generated_df, gpt_generated_df,ko_deplot_generated_txt, gpt_generated_txt, flag= real_time_check(image_uploader)
 
 
 
 
 
 
165
  if flag==1:
166
+ return gr.update(visible=False), gr.update(visible=False), gr.Text(ko_deplot_generated_txt,visible=True),gr.Text(gpt_generated_txt,visible=True)
167
  else:
168
+ return ko_deplot_generated_df, gpt_generated_df, gr.update(visible=False),gr.update(visible=False)
169
  else:
170
+ ko_deplot_generated_df, gpt_generated_df,ko_deplot_generated_txt, gpt_generated_txt, flag= real_time_check(image_files[current_image_index])
171
  if flag==1:
172
+ return gr.update(visible=False), gr.update(visible=False), gr.Text(ko_deplot_generated_txt,visible=True),gr.Text(gpt_generated_txt,visible=True)
173
  else:
174
+ return ko_deplot_generated_df, gpt_generated_df, gr.update(visible=False),gr.update(visible=False)
175
+
176
+ def toggle_model(selected_models,flag):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  # Create a visibility list initialized to False for all components
178
+ visibility = [False] * 6
 
179
  # Update visibility based on the selected models
180
  if "VAIV_DePlot" in selected_models:
181
+ visibility[4]= True
182
+ if flag:
183
+ visibility[2]= True
184
+ else:
185
+ visibility[0]= True
186
+ if "gpt-4o-mini" in selected_models:
187
+ visibility[5]= True
188
+ if flag:
189
+ visibility[3]= True
190
+ else:
191
+ visibility[1]= True
 
 
 
192
  if "all" in selected_models:
193
+ visibility[4]=True
194
+ visibility[5]=True
195
+ if flag:
196
+ visibility[2]= True
197
+ visibility[3]= True
198
+ else:
199
+ visibility[0]= True
200
+ visibility[1]= True
201
  # Return gr.update for each component with the corresponding visibility status
202
  return tuple(gr.update(visible=v) for v in visibility)
203
 
204
+ def toggle_mode(mode):
205
+ if mode == "ํŒŒ์ผ ์—…๋กœ๋“œ":
206
+ return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
207
+ else:
208
+ return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
209
 
210
  def display_image(image_file):
211
  image=Image.open(image_file)
212
  return image, os.path.basename(image_file)
213
 
214
+ # Function to display the images in the folder sequentially
215
+ image_files = []
216
+ current_image_index = 0
217
+ image_files_cnt=0
218
+
219
+ def display_folder_images(image_file_path_list):
220
+ global image_files, current_image_index,image_files_cnt
221
+ image_files = image_file_path_list
222
+ image_files_cnt=len(image_files)
223
+ current_image_index = 0
224
+ if image_files:
225
+ return Image.open(image_files[current_image_index]), os.path.basename(image_files[current_image_index]), gr.update(interactive=False), gr.update(interactive=True)
226
+ return None, "No images found"
227
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
+ def next_image():
230
+ global current_image_index
231
+ if image_files:
232
+ current_image_index = (current_image_index + 1)
233
+ prev_disabled = current_image_index == 0
234
+ next_disabled = current_image_index == (len(image_files) - 1)
235
+ return Image.open(image_files[current_image_index]), os.path.basename(image_files[current_image_index]), gr.update(interactive=not prev_disabled), gr.update(interactive= not next_disabled)
236
+ return None, "No images found"
237
+
238
+ def prev_image():
239
+ global current_image_index
240
+ if image_files:
241
+ current_image_index = (current_image_index - 1)
242
+ prev_disabled = current_image_index == 0
243
+ next_disabled = current_image_index == (len(image_files) - 1)
244
+ return Image.open(image_files[current_image_index]), os.path.basename(image_files[current_image_index]), gr.update(interactive=not prev_disabled), gr.update(interactive= not next_disabled)
245
+ return None, "No images found"
246
 
247
  css = """
248
  .dataframe-class {
 
249
  overflow-y: auto !important; /* ์Šคํฌ๋กค์„ ๊ฐ€๋Šฅํ•˜๊ฒŒ */
250
+ height: 250px
251
  }
252
  """
253
 
254
  with gr.Blocks(css=css) as iface:
255
+ with gr.Row():
256
+ gr.Markdown("<h1 style='text-align: center;'>SKKU-VAIV Automatic chart understanding evaluation tool</h1>")
257
+ gr.Markdown("<hr style='border: 1px solid #ddd;' />")
258
  with gr.Row():
259
  with gr.Column():
260
+ mode_selector = gr.Radio(["ํŒŒ์ผ ์—…๋กœ๋“œ", "ํด๋” ์—…๋กœ๋“œ"], label="Upload Mode", value="ํŒŒ์ผ ์—…๋กœ๋“œ")
261
+ image_uploader = gr.File(file_count="single", file_types=["image"], visible=True)
262
+ folder_uploader = gr.File(file_count="directory", file_types=["image"], visible=False, height=50)
263
+ model_type=gr.Dropdown(["VAIV_DePlot","gpt-4o-mini","all"],value="VAIV_DePlot",label="model",multiselect=True)
264
+ image_displayer = gr.Image(visible=True)
265
+ image_name = gr.Text("", visible=True)
 
 
 
 
 
 
 
 
266
  with gr.Row():
267
+ prev_button = gr.Button("์ด์ „", visible=False, interactive=False)
268
+ next_button = gr.Button("๋‹ค์Œ", visible=False, interactive=False)
269
+ inference_button = gr.Button("์ถ”๋ก ")
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  with gr.Column():
271
+ md1 = gr.Markdown("# VAIV_DePlot Inference Result")
272
+ ko_deplot_generated_df = gr.DataFrame(visible=True, elem_classes="dataframe-class")
273
+ ko_deplot_generated_txt = gr.Text(visible=False)
274
+ with gr.Column():
275
+ md2 = gr.Markdown("# gpt-4o-mini Inference Result", visible=False)
276
+ gpt_generated_df = gr.DataFrame(visible=False, elem_classes="dataframe-class")
277
+ gpt_generated_txt = gr.Text(visible=False)
278
+ #label_df = gr.DataFrame(visible=False, label="Ground Truth Table", elem_classes="dataframe-class",scale=1)
 
 
 
 
 
 
279
 
280
+ model_type.change(
281
+ toggle_model,
282
+ inputs=[model_type, gr.State(flag)],
283
+ outputs=[ko_deplot_generated_df,gpt_generated_df,ko_deplot_generated_txt,gpt_generated_txt,md1,md2]
284
+ )
285
+
286
+ mode_selector.change(
287
+ toggle_mode,
288
+ inputs=[mode_selector],
289
+ outputs=[image_uploader, folder_uploader, prev_button, next_button]
290
  )
291
 
 
292
  image_uploader.upload(display_image,inputs=[image_uploader],outputs=[image_displayer,image_name])
293
+ folder_uploader.upload(display_folder_images, inputs=[folder_uploader], outputs=[image_displayer, image_name, prev_button, next_button])
294
+ prev_button.click(prev_image, outputs=[image_displayer, image_name, prev_button, next_button])
295
+ next_button.click(next_image, outputs=[image_displayer, image_name, prev_button, next_button])
296
+ inference_button.click(inference,inputs=[image_uploader,mode_selector],outputs=[ko_deplot_generated_df, gpt_generated_df, ko_deplot_generated_txt, gpt_generated_txt])
297
 
298
+ if __name__ == "__main__":
299
+ iface.launch(share=True)