nicpopovic commited on
Commit
ee5a50d
·
verified ·
1 Parent(s): 86336aa

Upload 13 files

Browse files
app.py ADDED
@@ -0,0 +1,560 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, STOKEStreamer
3
+ from threading import Thread
4
+ import json
5
+ import torch
6
+ import os
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ from matplotlib.colors import to_hex
10
+ import itertools
11
+ import transformers
12
+ transformers.logging.set_verbosity_error()
13
+
14
+
15
# Number of (model, tokenizer) instances served round-robin by the app
# (used below for both loading and Gradio's concurrency_limit).
n_instances = 1

# Display name of the accelerator, shown in the UI footer; defaults to "CPU"
# when no CUDA device is visible.
gpu_name = "CPU"

# If several GPUs are present this keeps the name of the LAST one — it is
# only informational text, not a device selection.
for i in range(torch.cuda.device_count()):
    gpu_name = torch.cuda.get_device_properties(i).name
23
+ # Reusing the original MLP class and other functions (unchanged) except those specific to Streamlit
24
+ class MLP(torch.nn.Module):
25
+ def __init__(self, input_dim, output_dim, hidden_dim=1024, layer_id=0, cuda=False):
26
+ super(MLP, self).__init__()
27
+ self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
28
+ self.fc3 = torch.nn.Linear(hidden_dim, output_dim)
29
+ self.layer_id = layer_id
30
+ if cuda:
31
+ self.device = "cuda"
32
+ else:
33
+ self.device = "cpu"
34
+ self.to(self.device)
35
+
36
+ def forward(self, x):
37
+ x = torch.flatten(x, start_dim=1)
38
+ x = torch.relu(self.fc1(x))
39
+ x = self.fc3(x)
40
+ return torch.argmax(x, dim=-1).cpu().detach(), torch.softmax(x, dim=-1).cpu().detach()
41
+
42
+ def map_value_to_color(value, colormap_name='tab20c'):
43
+ value = np.clip(value, 0.0, 1.0)
44
+ colormap = plt.get_cmap(colormap_name)
45
+ rgba_color = colormap(value)
46
+ css_color = to_hex(rgba_color)
47
+ return css_color + "88"
48
+
49
+
50
+ # Caching functions for model and classifier
51
+ model_cache = {}
52
+
53
def get_multiple_model_and_tokenizer(name, n_instances):
    """Load `n_instances` independent (model, tokenizer) pairs for `name`.

    Each instance is loaded fresh from the hub (authenticated via the
    HF_TOKEN environment variable) in bfloat16 with device_map="auto",
    and moved to CUDA when available.
    """
    instances = []
    for _ in range(n_instances):
        tok = AutoTokenizer.from_pretrained(name, token=os.getenv('HF_TOKEN'), pad_token_id=128001)
        model = AutoModelForCausalLM.from_pretrained(name, token=os.getenv('HF_TOKEN'), torch_dtype="bfloat16", pad_token_id=128001, device_map="auto")
        if torch.cuda.is_available():
            model.cuda()
        instances.append((model, tok))
    return instances
62
+
63
+ def get_classifiers_for_model(att_size, emb_size, device, config_paths):
64
+ config = {
65
+ "classifier_token": json.load(open(os.path.join(config_paths["classifier_token"], "config.json"), "r")),
66
+ "classifier_span": json.load(open(os.path.join(config_paths["classifier_span"], "config.json"), "r"))
67
+ }
68
+ layer_id = config["classifier_token"]["layer"]
69
+
70
+ classifier_span = MLP(att_size, 2, hidden_dim=config["classifier_span"]["classifier_dim"]).to(device)
71
+ classifier_span.load_state_dict(torch.load(os.path.join(config_paths["classifier_span"], "checkpoint.pt"), map_location=device, weights_only=True))
72
+
73
+ classifier_token = MLP(emb_size, len(config["classifier_token"]["label_map"]), layer_id=layer_id, hidden_dim=config["classifier_token"]["classifier_dim"]).to(device)
74
+ classifier_token.load_state_dict(torch.load(os.path.join(config_paths["classifier_token"], "checkpoint.pt"), map_location=device, weights_only=True))
75
+
76
+ return classifier_span, classifier_token, config["classifier_token"]["label_map"]
77
+
78
+ def find_datasets_and_model_ids(root_dir):
79
+ datasets = {}
80
+ for root, dirs, files in os.walk(root_dir):
81
+ if 'config.json' in files and 'stoke_config.json' in files:
82
+ config_path = os.path.join(root, 'config.json')
83
+ stoke_config_path = os.path.join(root, 'stoke_config.json')
84
+
85
+ with open(config_path, 'r') as f:
86
+ config_data = json.load(f)
87
+ model_id = config_data.get('model_id')
88
+ if model_id:
89
+ dataset_name = os.path.basename(os.path.dirname(config_path))
90
+
91
+ with open(stoke_config_path, 'r') as f:
92
+ stoke_config_data = json.load(f)
93
+ if model_id:
94
+ dataset_name = os.path.basename(os.path.dirname(stoke_config_path))
95
+ datasets.setdefault(model_id, {})[dataset_name] = stoke_config_data
96
+ return datasets
97
+
98
+ def filter_spans(spans_and_values):
99
+ if spans_and_values == []:
100
+ return [], []
101
+ # Create a dictionary to store spans based on their second index values
102
+ span_dict = {}
103
+
104
+ spans, values = [x[0] for x in spans_and_values], [x[1] for x in spans_and_values]
105
+
106
+ # Iterate through the spans and update the dictionary with the highest value
107
+ for span, value in zip(spans, values):
108
+ start, end = span
109
+ if start > end or end - start > 15 or start == 0:
110
+ continue
111
+ current_value = span_dict.get(end, None)
112
+
113
+ if current_value is None or current_value[1] < value:
114
+ span_dict[end] = (span, value)
115
+
116
+ if span_dict == {}:
117
+ return [], []
118
+ # Extract the filtered spans and values
119
+ filtered_spans, filtered_values = zip(*span_dict.values())
120
+
121
+ return list(filtered_spans), list(filtered_values)
122
+
123
+ def remove_overlapping_spans(spans):
124
+ # Sort the spans based on their end points
125
+ sorted_spans = sorted(spans, key=lambda x: x[0][1])
126
+
127
+ non_overlapping_spans = []
128
+ last_end = float('-inf')
129
+
130
+ # Iterate through the sorted spans
131
+ for span in sorted_spans:
132
+ start, end = span[0]
133
+ value = span[1]
134
+
135
+ # If the current span does not overlap with the previous one
136
+ if start >= last_end:
137
+ non_overlapping_spans.append(span)
138
+ last_end = end
139
+ else:
140
+ # If it overlaps, choose the one with the highest value
141
+ existing_span_index = -1
142
+ for i, existing_span in enumerate(non_overlapping_spans):
143
+ if existing_span[0][1] <= start:
144
+ existing_span_index = i
145
+ break
146
+ if existing_span_index != -1 and non_overlapping_spans[existing_span_index][1] < value:
147
+ non_overlapping_spans[existing_span_index] = span
148
+
149
+ return non_overlapping_spans
150
+
151
+ def generate_html_no_overlap(tokenized_text, spans):
152
+ current_index = 0
153
+ html_content = ""
154
+
155
+ for (span_start, span_end), value in spans:
156
+ # Add text before the span
157
+ html_content += "".join(tokenized_text[current_index:span_start])
158
+
159
+ # Add the span with underlining
160
+ html_content += "<b><u>"
161
+ html_content += "".join(tokenized_text[span_start:span_end])
162
+ html_content += "</u></b> "
163
+
164
+ current_index = span_end
165
+
166
+ # Add any remaining text after the last span
167
+ html_content += "".join(tokenized_text[current_index:])
168
+
169
+ return html_content
170
+
171
+
172
+ css = """
173
+ <style>
174
+ .prose {
175
+ line-height: 200%;
176
+ }
177
+ .highlight {
178
+ display: inline;
179
+ }
180
+ .highlight::after {
181
+ background-color: var(data-color);
182
+ }
183
+ .spanhighlight {
184
+ padding: 2px 5px;
185
+ border-radius: 5px;
186
+ }
187
+ .tooltip {
188
+ position: relative;
189
+ display: inline-block;
190
+ }
191
+ .generated-content {
192
+ margin-top: -1em;
193
+ height: 130px;
194
+ }
195
+ .tooltip::after {
196
+ content: attr(data-tooltip-text); /* Set content from data-tooltip-text attribute */
197
+ display: none;
198
+ position: absolute;
199
+ background-color: #333;
200
+ color: #fff;
201
+ padding: 5px;
202
+ border-radius: 5px;
203
+ bottom: 100%; /* Position it above the element */
204
+ left: 50%;
205
+ transform: translateX(-50%);
206
+ width: auto;
207
+ min-width: 120px;
208
+ margin: 0 auto;
209
+ text-align: center;
210
+ }
211
+
212
+ .tooltip:hover::after {
213
+ display: block; /* Show the tooltip on hover */
214
+ }
215
+
216
+ .small-text {
217
+ padding: 2px 5px;
218
+ background-color: white;
219
+ border-radius: 5px;
220
+ font-size: xx-small;
221
+ margin-left: 0.5em;
222
+ vertical-align: 0.2em;
223
+ font-weight: bold;
224
+ color: grey!important;
225
+ }
226
+ </style>"""
227
+
228
+
229
+ def generate_html_spanwise(token_strings, tokenwise_preds, spans, tokenizer, new_tags):
230
+
231
+ # spanwise annotated text
232
+ annotated = []
233
+ span_ends = -1
234
+ in_span = False
235
+
236
+ out_of_span_tokens = []
237
+ for i in reversed(range(len(tokenwise_preds))):
238
+
239
+ if in_span:
240
+ if i >= span_ends:
241
+ continue
242
+ else:
243
+ in_span = False
244
+
245
+ predicted_class = ""
246
+ style = ""
247
+
248
+ span = None
249
+ for s in spans:
250
+ if s[1] == i+1:
251
+ span = s
252
+
253
+ if tokenwise_preds[i] != 0 and span is not None:
254
+ predicted_class = f"highlight spanhighlight"
255
+ style = f"background-color: {map_value_to_color((tokenwise_preds[i]-1)/(len(new_tags)-1))}"
256
+ if tokenizer.convert_tokens_to_string([token_strings[i]]).startswith(" "):
257
+ annotated.append("Ġ")
258
+
259
+ span_opener = f"Ġ<span class='{predicted_class}' data-tooltip-text='{new_tags[tokenwise_preds[i]]}' style='{style}'>".replace(" ", "Ġ")
260
+ span_end = f"<span class='small-text'>{new_tags[tokenwise_preds[i]]}</span></span>"
261
+ annotated.extend(out_of_span_tokens)
262
+ out_of_span_tokens = []
263
+ span_ends = span[0]
264
+ in_span = True
265
+ annotated.append(span_end)
266
+ annotated.extend([token_strings[x] for x in reversed(range(span[0], span[1]))])
267
+ annotated.append(span_opener)
268
+ else:
269
+ out_of_span_tokens.append(token_strings[i])
270
+
271
+ annotated.extend(out_of_span_tokens)
272
+
273
+ return [x for x in reversed(annotated)]
274
+
275
+ def gen_json(input_text, max_new_tokens):
276
+ streamer = STOKEStreamer(tok, classifier_token, classifier_span)
277
+
278
+ new_tags = label_map
279
+
280
+ inputs = tok([f" {input_text}"], return_tensors="pt").to(model.device)
281
+ generation_kwargs = dict(
282
+ inputs, streamer=streamer, max_new_tokens=max_new_tokens,
283
+ repetition_penalty=1.2, do_sample=False
284
+ )
285
+
286
+ def generate_async():
287
+ model.generate(**generation_kwargs)
288
+
289
+ thread = Thread(target=generate_async)
290
+ thread.start()
291
+
292
+ # Display generated text as it becomes available
293
+ output_text = ""
294
+ text_tokenwise = ""
295
+ text_spans = ""
296
+ removed_spans = ""
297
+ tags = []
298
+ spans = []
299
+ for new_text in streamer:
300
+ if new_text[1] is not None and new_text[2] != ['']:
301
+ text_tokenwise = ""
302
+ output_text = ""
303
+ tags.extend(new_text[1])
304
+ spans.extend(new_text[-1])
305
+
306
+ # Tokenwise Classification
307
+ for tk, pred in zip(new_text[2],tags):
308
+ if pred != 0:
309
+ style = f"background-color: {map_value_to_color((pred-1)/(len(new_tags)-1))}"
310
+ if tk.startswith(" "):
311
+ text_tokenwise += " "
312
+ text_tokenwise += f"<span class='tooltip highlight' data-tooltip-text='{new_tags[pred]}' style='{style}'>{tk}</span>"
313
+ output_text += tk
314
+ else:
315
+ text_tokenwise += tk
316
+ output_text += tk
317
+
318
+ # Span Classification
319
+ text_spans = ""
320
+ if len(spans) > 0:
321
+ filtered_spans = remove_overlapping_spans(spans)
322
+ text_spans = generate_html_no_overlap(new_text[2], filtered_spans)
323
+ if len(spans) - len(filtered_spans) > 0:
324
+ removed_spans = f"{len(spans) - len(filtered_spans)} span(s) hidden due to overlap."
325
+ else:
326
+ for tk in new_text[2]:
327
+ text_spans += f"{tk}"
328
+
329
+ # Spanwise Classification
330
+ annotated_tokens = generate_html_spanwise(new_text[2], tags, [x for x in filter_spans(spans)[0]], tok, new_tags)
331
+ generated_text_spanwise = tok.convert_tokens_to_string(annotated_tokens).replace("<|endoftext|>", "").replace("<|begin_of_text|>", "")
332
+
333
+ output = f"{css}<br>"
334
+ output += generated_text_spanwise.replace("\n", " ").replace("$", "$") + "\n<br>"
335
+ #output += "<h5>Show tokenwise classification</h5>\n" + text_tokenwise.replace("\n", " ").replace("$", "\\$").replace("<|endoftext|>", "").replace("<|begin_of_text|>", "")
336
+ #output += "</details><details><summary>Show spans</summary>\n" + text_spans.replace("\n", " ").replace("$", "\\$")
337
+ #if removed_spans != "":
338
+ # output += f"<br><br><i>({removed_spans})</i>"
339
+ list_of_spans = [{"name": tok.convert_tokens_to_string(new_text[2][x[0]:x[1]]).strip(), "type": new_tags[tags[x[1]-1]]} for x in filter_spans(spans)[0] if new_tags[tags[x[1]-1]] != "O"]
340
+
341
+ out_dict = {"text": output_text.replace("<|endoftext|>", "").replace("<|begin_of_text|>", "".strip()), "entites": list_of_spans}
342
+
343
+ yield out_dict
344
+ return
345
+
346
+ # Gradio app function to generate text using the assigned model instance
347
+ def generate_text(input_text, max_new_tokens=40):
348
+ if input_text == "":
349
+ yield "Please enter some text first."
350
+ return
351
+
352
+ # Select the next model instance in a round-robin manner
353
+ model, tok = next(model_round_robin)
354
+
355
+ generate_button.visible = False
356
+ streamer = STOKEStreamer(tok, classifier_token, classifier_span)
357
+
358
+ new_tags = label_map
359
+
360
+ inputs = tok([f" {input_text[:200]}"], return_tensors="pt").to(model.device)
361
+ generation_kwargs = dict(
362
+ inputs, streamer=streamer, max_new_tokens=max_new_tokens,
363
+ repetition_penalty=1.2, do_sample=False, temperature=None, top_p=None
364
+ )
365
+
366
+ def generate_async():
367
+ model.generate(**generation_kwargs)
368
+
369
+ thread = Thread(target=generate_async)
370
+ thread.start()
371
+
372
+ # Display generated text as it becomes available
373
+ output_text = ""
374
+ text_tokenwise = ""
375
+ text_spans = ""
376
+ removed_spans = ""
377
+ tags = []
378
+ spans = []
379
+ for new_text in streamer:
380
+ if new_text[1] is not None and new_text[2] != ['']:
381
+ text_tokenwise = ""
382
+ output_text = ""
383
+ tags.extend(new_text[1])
384
+ spans.extend(new_text[-1])
385
+
386
+ # Tokenwise Classification
387
+ for tk, pred in zip(new_text[2],tags):
388
+ if pred != 0:
389
+ style = f"background-color: {map_value_to_color((pred-1)/(len(new_tags)-1))}"
390
+ if tk.startswith(" "):
391
+ text_tokenwise += " "
392
+ text_tokenwise += f"<span class='tooltip highlight' data-tooltip-text='{new_tags[pred]}' style='{style}'>{tk}</span>"
393
+ output_text += tk
394
+ else:
395
+ text_tokenwise += tk
396
+ output_text += tk
397
+
398
+ # Span Classification
399
+ text_spans = ""
400
+ if len(spans) > 0:
401
+ filtered_spans = remove_overlapping_spans(spans)
402
+ text_spans = generate_html_no_overlap(new_text[2], filtered_spans)
403
+ if len(spans) - len(filtered_spans) > 0:
404
+ removed_spans = f"{len(spans) - len(filtered_spans)} span(s) hidden due to overlap."
405
+ else:
406
+ for tk in new_text[2]:
407
+ text_spans += f"{tk}"
408
+
409
+ # Spanwise Classification
410
+ annotated_tokens = generate_html_spanwise(new_text[2], tags, [x for x in filter_spans(spans)[0]], tok, new_tags)
411
+ generated_text_spanwise = tok.convert_tokens_to_string(annotated_tokens).replace("<|endoftext|>", "").replace("<|begin_of_text|>", "")
412
+
413
+ output = f"{css}<div class=\"generated-content\"><br>"
414
+ output += generated_text_spanwise.replace("\n", " ").replace("$", "$") + "\n<br>"
415
+
416
+ list_of_spans = [{"name": tok.convert_tokens_to_string(new_text[2][x[0]:x[1]]).strip(), "type": new_tags[tags[x[1]-1]]} for x in filter_spans(spans)[0] if new_tags[tags[x[1]-1]] != "O"]
417
+
418
+ out_dict = {"text": output_text.replace("<|endoftext|>", "").replace("<|begin_of_text|>", "").strip(), "entites": list_of_spans}
419
+
420
+ yield output + "</div>"
421
+
422
+ generate_button.visible = True
423
+ return
424
+
425
+
426
+ # Load datasets and models for the Gradio app
427
+ datasets = find_datasets_and_model_ids("data/")
428
+ available_models = list(datasets.keys())
429
+ available_datasets = {model: list(datasets[model].keys()) for model in available_models}
430
+ available_configs = {model: {dataset: list(datasets[model][dataset].keys()) for dataset in available_datasets[model]} for model in available_models}
431
+
432
+ def update_datasets(model_name):
433
+ return available_datasets[model_name]
434
+
435
+ def update_configs(model_name, dataset_name):
436
+ return available_configs[model_name][dataset_name]
437
+
438
+ # Load datasets and models for the Gradio app
439
+ datasets = find_datasets_and_model_ids("data/")
440
+ available_models = list(datasets.keys())
441
+ available_datasets = {model: list(datasets[model].keys()) for model in available_models}
442
+ available_configs = {model: {dataset: list(datasets[model][dataset].keys()) for dataset in available_datasets[model]} for model in available_models}
443
+
444
+ # Set the model ID and data configurations
445
+ model_id = "meta-llama/Llama-3.2-1B"
446
+ data_id = "STOKE_100"
447
+ config_id = "default"
448
+
449
+ # Load n_instances separate instances of the model and tokenizer
450
+ model_instances = get_multiple_model_and_tokenizer(model_id, n_instances)
451
+
452
+ # Set up the round-robin iterator to distribute the requests across model instances
453
+ model_round_robin = itertools.cycle(model_instances)
454
+
455
+
456
+ # Load model classifiers
457
+ try:
458
+ classifier_span, classifier_token, label_map = get_classifiers_for_model(
459
+ model_instances[0][0].config.n_head * model_instances[0][0].config.n_layer, model_instances[0][0].config.n_embd, model_instances[0][0].device,
460
+ datasets[model_id][data_id][config_id]
461
+ )
462
+ except:
463
+ classifier_span, classifier_token, label_map = get_classifiers_for_model(
464
+ model_instances[0][0].config.num_attention_heads * model_instances[0][0].config.num_hidden_layers, model_instances[0][0].config.hidden_size, model_instances[0][0].device,
465
+ datasets[model_id][data_id][config_id]
466
+ )
467
+
468
+ initial_output = (css+"""<div class=\"generated-content\"><br><style>
469
+ .prose {
470
+ line-height: 200%;
471
+ }
472
+ .highlight {
473
+ display: inline;
474
+ }
475
+ .highlight::after {
476
+ background-color: var(data-color);
477
+ }
478
+ .spanhighlight {
479
+ padding: 2px 5px;
480
+ border-radius: 5px;
481
+ }
482
+ .tooltip {
483
+ position: relative;
484
+ display: inline-block;
485
+ }
486
+
487
+ .tooltip::after {
488
+ content: attr(data-tooltip-text); /* Set content from data-tooltip-text attribute */
489
+ display: none;
490
+ position: absolute;
491
+ background-color: #333;
492
+ color: #fff;
493
+ padding: 5px;
494
+ border-radius: 5px;
495
+ bottom: 100%; /* Position it above the element */
496
+ left: 50%;
497
+ transform: translateX(-50%);
498
+ width: auto;
499
+ min-width: 120px;
500
+ margin: 0 auto;
501
+ text-align: center;
502
+ }
503
+
504
+ .tooltip:hover::after {
505
+ display: block; /* Show the tooltip on hover */
506
+ }
507
+
508
+ .small-text {
509
+ padding: 2px 5px;
510
+ background-color: white;
511
+ border-radius: 5px;
512
+ font-size: xx-small;
513
+ margin-left: 0.5em;
514
+ vertical-align: 0.2em;
515
+ font-weight: bold;
516
+ color: grey!important;
517
+ }
518
+ </style><span class='highlight spanhighlight' data-tooltip-text='GPE' style='background-color: #e6550d88'> Miami<span class='small-text'>GPE</span></span> is a city in the <span class='highlight spanhighlight' data-tooltip-text='GPE' style='background-color: #e6550d88'> U.S.<span class='small-text'>GPE</span></span> state of <span class='highlight spanhighlight' data-tooltip-text='GPE' style='background-color: #e6550d88'> Florida<span class='small-text'>GPE</span></span>, and it's also known as " <span class='highlight spanhighlight' data-tooltip-text='WORK_OF_ART' style='background-color: #bdbdbd88'>The Magic City<span class='small-text'>WORK_OF_ART</span></span>." It was founded by <span class='highlight spanhighlight' data-tooltip-text='PERSON' style='background-color: #bcbddc88'> Henry Flagler<span class='small-text'>PERSON</span></span> on <span class='highlight spanhighlight' data-tooltip-text='DATE' style='background-color: #6baed688'> October 28th, 1896<span class='small-text'>DATE</span></span>.
519
+ <br></div>""", {'text': 'Miami is a city in the U.S. state of Florida, and it\'s also known as "The Magic City." It was founded by Henry Flagler on October 28th, 1896.', 'entites': [{'name': 'Miami', 'type': 'GPE'}, {'name': 'U.S.', 'type': 'GPE'}, {'name': 'Florida', 'type': 'GPE'}, {'name': 'The Magic City', 'type': 'WORK_OF_ART'}, {'name': 'Henry Flagler', 'type': 'PERSON'}, {'name': 'October 28th, 1896', 'type': 'DATE'}]})
520
+
521
+
522
+ with gr.Blocks(css="footer{display:none !important} .gradio-container {padding: 0!important; height:400px;}", fill_width=True) as demo:
523
+ with gr.Tab("EMBER Demo"):
524
+ with gr.Row():
525
+ output_text = gr.HTML(label="Generated Text", value=initial_output[0])
526
+ with gr.Group():
527
+ with gr.Row():
528
+ input_text = gr.Textbox(label="Enter prompt for completion", value="Miami is", max_length=200)
529
+ generate_button = gr.Button("Generate", scale=0)
530
+ # New HTML output for model info
531
+ model_info_html = gr.HTML(
532
+ label="Model Info",
533
+ value=f'<div style="font-weight: lighter; text-align: center; margin-top: -1.5em; margin-bottom: -1em!important; font-size: x-small;">{model_id} running on {gpu_name}</div>'
534
+ )
535
+
536
+
537
+ generate_button.click(
538
+ fn=generate_text,
539
+ inputs=[input_text],
540
+ outputs=[output_text],
541
+ concurrency_limit=n_instances,
542
+ concurrency_id="queue"
543
+ )
544
+
545
+ # Function to refresh the model info HTML
546
+ def refresh_model_info():
547
+ return f'<div style="font-weight: lighter; text-align: center; margin-top: -1.5em; margin-bottom: -1em!important; font-size: x-small;">{model_id} running on {gpu_name}</div>'
548
+
549
+ # Update the model info HTML on button click
550
+ generate_button.click(
551
+ fn=refresh_model_info,
552
+ inputs=[],
553
+ outputs=[model_info_html],
554
+ queue=False
555
+ )
556
+
557
+
558
+ demo.queue()
559
+
560
+ demo.launch(server_name="0.0.0.0", server_port=7860)
data/meta-llama/Llama-3.2-1B/STOKE_100/checkpoints/span_classifier/Rxi8b70XJA/checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:25951d9b73437a7aa344f4c207cbda2f88d9bf5fa94d1a779617948b18a1c4ed
3
+ size 8439912
data/meta-llama/Llama-3.2-1B/STOKE_100/checkpoints/span_classifier/Rxi8b70XJA/config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model": "meta-llama/Llama-3.2-1B",
3
+ "type": "span_classifier",
4
+ "label_map": [
5
+ "no_span",
6
+ "span"
7
+ ],
8
+ "learning_rate": 0.0003,
9
+ "classifier_dim": 4096,
10
+ "loss_weights": [
11
+ 1.0,
12
+ 1.0
13
+ ],
14
+ "identifier": "Rxi8b70XJA",
15
+ "best_f1_validation": 0.8677362203598022,
16
+ "best_f1_validation_classwise": {
17
+ "span": {
18
+ "p": 0.896858811378479,
19
+ "r": 0.8404456377029419,
20
+ "f": 0.867736279964447,
21
+ "s": 24324.0
22
+ },
23
+ "macro": {
24
+ "p": 0.896858811378479,
25
+ "r": 0.8404456377029419,
26
+ "f": 0.867736279964447
27
+ }
28
+ }
29
+ }
data/meta-llama/Llama-3.2-1B/STOKE_100/checkpoints/span_classifier/Rxi8b70XJA/config_train.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "path": "data/meta-llama/Llama-3.2-1B/STOKE_100",
3
+ "splits": [
4
+ "train",
5
+ "validation"
6
+ ],
7
+ "layers": [
8
+ 8,
9
+ 9,
10
+ 10,
11
+ 11,
12
+ 12
13
+ ],
14
+ "hfcache": "",
15
+ "classifier_dims": [
16
+ 4096
17
+ ],
18
+ "learning_rates": [
19
+ 0.0001,
20
+ 5e-05,
21
+ 0.0003
22
+ ],
23
+ "cuda": true,
24
+ "n_steps_per_epoch": 10000,
25
+ "n_epochs": 30,
26
+ "batch_size": 8,
27
+ "balance_loss": false,
28
+ "loss_weights_span": [
29
+ [
30
+ 1.0,
31
+ 1.0
32
+ ],
33
+ [
34
+ 1.0,
35
+ 50.0
36
+ ],
37
+ [
38
+ 1.0,
39
+ 100.0
40
+ ]
41
+ ],
42
+ "time": 1727765390.5829365,
43
+ "config_dataset": {
44
+ "generation_kwargs": {
45
+ "max_new_tokens": 100,
46
+ "repetition_penalty": 1.2
47
+ },
48
+ "model_id": "meta-llama/Llama-3.2-1B",
49
+ "flair_model_name": "flair/ner-english-ontonotes-large"
50
+ }
51
+ }
data/meta-llama/Llama-3.2-1B/STOKE_100/checkpoints/token_classifier/dR8xQB4ODU/checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dce5b3038d8767430a8bba16af61ec6af67c9d1aedc75a9f34c01feebac09b6e
3
+ size 33884328
data/meta-llama/Llama-3.2-1B/STOKE_100/checkpoints/token_classifier/dR8xQB4ODU/config.json ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "layer": 10,
3
+ "model": "meta-llama/Llama-3.2-1B",
4
+ "type": "token_classifier",
5
+ "label_map": [
6
+ "O",
7
+ "CARDINAL",
8
+ "DATE",
9
+ "EVENT",
10
+ "FAC",
11
+ "GPE",
12
+ "LANGUAGE",
13
+ "LAW",
14
+ "LOC",
15
+ "MONEY",
16
+ "NORP",
17
+ "ORDINAL",
18
+ "ORG",
19
+ "PERCENT",
20
+ "PERSON",
21
+ "PRODUCT",
22
+ "QUANTITY",
23
+ "TIME",
24
+ "WORK_OF_ART"
25
+ ],
26
+ "learning_rate": 5e-05,
27
+ "classifier_dim": 4096,
28
+ "loss_weights": [
29
+ 1.0,
30
+ 1.0,
31
+ 1.0,
32
+ 1.0,
33
+ 1.0,
34
+ 1.0,
35
+ 1.0,
36
+ 1.0,
37
+ 1.0,
38
+ 1.0,
39
+ 1.0,
40
+ 1.0,
41
+ 1.0,
42
+ 1.0,
43
+ 1.0,
44
+ 1.0,
45
+ 1.0,
46
+ 1.0,
47
+ 1.0
48
+ ],
49
+ "identifier": "dR8xQB4ODU",
50
+ "best_f1_validation": 0.9056437015533447,
51
+ "best_f1_validation_classwise": {
52
+ "CARDINAL": {
53
+ "p": 0.8679801225662231,
54
+ "r": 0.8777581453323364,
55
+ "f": 0.8728417754173279,
56
+ "s": 10741.0
57
+ },
58
+ "DATE": {
59
+ "p": 0.9519810676574707,
60
+ "r": 0.9389873743057251,
61
+ "f": 0.9454395771026611,
62
+ "s": 8572.0
63
+ },
64
+ "EVENT": {
65
+ "p": 0.8587140440940857,
66
+ "r": 0.8319672346115112,
67
+ "f": 0.8451290726661682,
68
+ "s": 1220.0
69
+ },
70
+ "FAC": {
71
+ "p": 0.8515185713768005,
72
+ "r": 0.8122317790985107,
73
+ "f": 0.8314113020896912,
74
+ "s": 932.0
75
+ },
76
+ "GPE": {
77
+ "p": 0.9000998735427856,
78
+ "r": 0.9094448685646057,
79
+ "f": 0.904748260974884,
80
+ "s": 6935.0
81
+ },
82
+ "LANGUAGE": {
83
+ "p": 0.75,
84
+ "r": 0.7200000286102295,
85
+ "f": 0.7346938848495483,
86
+ "s": 25.0
87
+ },
88
+ "LAW": {
89
+ "p": 0.8709677457809448,
90
+ "r": 0.73828125,
91
+ "f": 0.7991543412208557,
92
+ "s": 256.0
93
+ },
94
+ "LOC": {
95
+ "p": 0.8258426785469055,
96
+ "r": 0.7101449370384216,
97
+ "f": 0.7636363506317139,
98
+ "s": 414.0
99
+ },
100
+ "MONEY": {
101
+ "p": 0.876042902469635,
102
+ "r": 0.8626760840415955,
103
+ "f": 0.8693081140518188,
104
+ "s": 1704.0
105
+ },
106
+ "NORP": {
107
+ "p": 0.9160357713699341,
108
+ "r": 0.887333333492279,
109
+ "f": 0.9014561772346497,
110
+ "s": 1500.0
111
+ },
112
+ "ORDINAL": {
113
+ "p": 0.9303238391876221,
114
+ "r": 0.9498997926712036,
115
+ "f": 0.9400099515914917,
116
+ "s": 998.0
117
+ },
118
+ "ORG": {
119
+ "p": 0.8974575400352478,
120
+ "r": 0.8792765140533447,
121
+ "f": 0.8882739543914795,
122
+ "s": 9675.0
123
+ },
124
+ "PERCENT": {
125
+ "p": 0.8629592657089233,
126
+ "r": 0.8083720803260803,
127
+ "f": 0.8347742557525635,
128
+ "s": 1075.0
129
+ },
130
+ "PERSON": {
131
+ "p": 0.9707135558128357,
132
+ "r": 0.9713156223297119,
133
+ "f": 0.9710144996643066,
134
+ "s": 12899.0
135
+ },
136
+ "PRODUCT": {
137
+ "p": 0.7828418016433716,
138
+ "r": 0.7564767003059387,
139
+ "f": 0.7694334387779236,
140
+ "s": 386.0
141
+ },
142
+ "QUANTITY": {
143
+ "p": 0.8409090638160706,
144
+ "r": 0.7758846879005432,
145
+ "f": 0.8070893287658691,
146
+ "s": 763.0
147
+ },
148
+ "TIME": {
149
+ "p": 0.8710959553718567,
150
+ "r": 0.8373362421989441,
151
+ "f": 0.8538825511932373,
152
+ "s": 1832.0
153
+ },
154
+ "WORK_OF_ART": {
155
+ "p": 0.7803030014038086,
156
+ "r": 0.7152777910232544,
157
+ "f": 0.7463768124580383,
158
+ "s": 576.0
159
+ },
160
+ "macro": {
161
+ "p": 0.8669881820678711,
162
+ "r": 0.8323702216148376,
163
+ "f": 0.8488152027130127
164
+ }
165
+ }
166
+ }
data/meta-llama/Llama-3.2-1B/STOKE_100/checkpoints/token_classifier/dR8xQB4ODU/config_train.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "path": "data/meta-llama/Llama-3.2-1B/STOKE_100",
3
+ "splits": [
4
+ "train",
5
+ "validation"
6
+ ],
7
+ "layers": [
8
+ 8,
9
+ 9,
10
+ 10,
11
+ 11,
12
+ 12
13
+ ],
14
+ "hfcache": "",
15
+ "classifier_dims": [
16
+ 4096
17
+ ],
18
+ "learning_rates": [
19
+ 0.0001,
20
+ 5e-05,
21
+ 0.0003
22
+ ],
23
+ "cuda": true,
24
+ "n_steps_per_epoch": 10000,
25
+ "n_epochs": 30,
26
+ "batch_size": 8,
27
+ "balance_loss": false,
28
+ "loss_weights_span": [
29
+ [
30
+ 1.0,
31
+ 1.0
32
+ ],
33
+ [
34
+ 1.0,
35
+ 50.0
36
+ ],
37
+ [
38
+ 1.0,
39
+ 100.0
40
+ ]
41
+ ],
42
+ "time": 1727765390.5829365,
43
+ "config_dataset": {
44
+ "generation_kwargs": {
45
+ "max_new_tokens": 100,
46
+ "repetition_penalty": 1.2
47
+ },
48
+ "model_id": "meta-llama/Llama-3.2-1B",
49
+ "flair_model_name": "flair/ner-english-ontonotes-large"
50
+ }
51
+ }
data/meta-llama/Llama-3.2-1B/STOKE_100/checkpoints/token_classifier/pbK46jjAVx/checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f32816959f5fd27967c754a61b07d8ae6c92b7881e2fbb6a68b54b8c0c575122
3
+ size 33884328
data/meta-llama/Llama-3.2-1B/STOKE_100/checkpoints/token_classifier/pbK46jjAVx/config.json ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "layer": 10,
3
+ "model": "meta-llama/Llama-3.2-1B",
4
+ "type": "token_classifier",
5
+ "label_map": [
6
+ "O",
7
+ "CARDINAL",
8
+ "DATE",
9
+ "EVENT",
10
+ "FAC",
11
+ "GPE",
12
+ "LANGUAGE",
13
+ "LAW",
14
+ "LOC",
15
+ "MONEY",
16
+ "NORP",
17
+ "ORDINAL",
18
+ "ORG",
19
+ "PERCENT",
20
+ "PERSON",
21
+ "PRODUCT",
22
+ "QUANTITY",
23
+ "TIME",
24
+ "WORK_OF_ART"
25
+ ],
26
+ "learning_rate": 0.0003,
27
+ "classifier_dim": 4096,
28
+ "loss_weights": [
29
+ 1.0,
30
+ 1.0,
31
+ 1.0,
32
+ 1.0,
33
+ 1.0,
34
+ 1.0,
35
+ 1.0,
36
+ 1.0,
37
+ 1.0,
38
+ 1.0,
39
+ 1.0,
40
+ 1.0,
41
+ 1.0,
42
+ 1.0,
43
+ 1.0,
44
+ 1.0,
45
+ 1.0,
46
+ 1.0,
47
+ 1.0
48
+ ],
49
+ "identifier": "pbK46jjAVx",
50
+ "best_f1_validation": 0.9048610329627991,
51
+ "best_f1_validation_classwise": {
52
+ "CARDINAL": {
53
+ "p": 0.8730558156967163,
54
+ "r": 0.8727306723594666,
55
+ "f": 0.872893214225769,
56
+ "s": 10741.0
57
+ },
58
+ "DATE": {
59
+ "p": 0.9534441828727722,
60
+ "r": 0.9365375638008118,
61
+ "f": 0.944915235042572,
62
+ "s": 8572.0
63
+ },
64
+ "EVENT": {
65
+ "p": 0.8540268540382385,
66
+ "r": 0.83442622423172,
67
+ "f": 0.844112753868103,
68
+ "s": 1220.0
69
+ },
70
+ "FAC": {
71
+ "p": 0.8227027058601379,
72
+ "r": 0.8165236115455627,
73
+ "f": 0.8196015357971191,
74
+ "s": 932.0
75
+ },
76
+ "GPE": {
77
+ "p": 0.9014912247657776,
78
+ "r": 0.9065608978271484,
79
+ "f": 0.9040189981460571,
80
+ "s": 6935.0
81
+ },
82
+ "LANGUAGE": {
83
+ "p": 0.7272727489471436,
84
+ "r": 0.6399999856948853,
85
+ "f": 0.6808510422706604,
86
+ "s": 25.0
87
+ },
88
+ "LAW": {
89
+ "p": 0.8500000238418579,
90
+ "r": 0.73046875,
91
+ "f": 0.7857142686843872,
92
+ "s": 256.0
93
+ },
94
+ "LOC": {
95
+ "p": 0.8867924809455872,
96
+ "r": 0.6811594367027283,
97
+ "f": 0.7704918384552002,
98
+ "s": 414.0
99
+ },
100
+ "MONEY": {
101
+ "p": 0.873665452003479,
102
+ "r": 0.8644366264343262,
103
+ "f": 0.8690265417098999,
104
+ "s": 1704.0
105
+ },
106
+ "NORP": {
107
+ "p": 0.9220505356788635,
108
+ "r": 0.875333309173584,
109
+ "f": 0.898084819316864,
110
+ "s": 1500.0
111
+ },
112
+ "ORDINAL": {
113
+ "p": 0.9244186282157898,
114
+ "r": 0.9559118151664734,
115
+ "f": 0.9399014711380005,
116
+ "s": 998.0
117
+ },
118
+ "ORG": {
119
+ "p": 0.8920637965202332,
120
+ "r": 0.8841343522071838,
121
+ "f": 0.888081431388855,
122
+ "s": 9675.0
123
+ },
124
+ "PERCENT": {
125
+ "p": 0.8530852198600769,
126
+ "r": 0.8102325797080994,
127
+ "f": 0.8311069011688232,
128
+ "s": 1075.0
129
+ },
130
+ "PERSON": {
131
+ "p": 0.9692212343215942,
132
+ "r": 0.9716256856918335,
133
+ "f": 0.9704219698905945,
134
+ "s": 12899.0
135
+ },
136
+ "PRODUCT": {
137
+ "p": 0.7886179089546204,
138
+ "r": 0.7538859844207764,
139
+ "f": 0.7708609104156494,
140
+ "s": 386.0
141
+ },
142
+ "QUANTITY": {
143
+ "p": 0.8215258717536926,
144
+ "r": 0.7903014421463013,
145
+ "f": 0.8056111931800842,
146
+ "s": 763.0
147
+ },
148
+ "TIME": {
149
+ "p": 0.8752886652946472,
150
+ "r": 0.8275108933448792,
151
+ "f": 0.8507295250892639,
152
+ "s": 1832.0
153
+ },
154
+ "WORK_OF_ART": {
155
+ "p": 0.7937743067741394,
156
+ "r": 0.7083333134651184,
157
+ "f": 0.7486238479614258,
158
+ "s": 576.0
159
+ },
160
+ "macro": {
161
+ "p": 0.8656943440437317,
162
+ "r": 0.8255618214607239,
163
+ "f": 0.8441693186759949
164
+ }
165
+ }
166
+ }
data/meta-llama/Llama-3.2-1B/STOKE_100/checkpoints/token_classifier/pbK46jjAVx/config_train.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "path": "data/meta-llama/Llama-3.2-1B/STOKE_100",
3
+ "splits": [
4
+ "train",
5
+ "validation"
6
+ ],
7
+ "layers": [
8
+ 8,
9
+ 9,
10
+ 10,
11
+ 11,
12
+ 12
13
+ ],
14
+ "hfcache": "",
15
+ "classifier_dims": [
16
+ 4096
17
+ ],
18
+ "learning_rates": [
19
+ 0.0001,
20
+ 5e-05,
21
+ 0.0003
22
+ ],
23
+ "cuda": true,
24
+ "n_steps_per_epoch": 10000,
25
+ "n_epochs": 30,
26
+ "batch_size": 8,
27
+ "balance_loss": false,
28
+ "loss_weights_span": [
29
+ [
30
+ 1.0,
31
+ 1.0
32
+ ],
33
+ [
34
+ 1.0,
35
+ 50.0
36
+ ],
37
+ [
38
+ 1.0,
39
+ 100.0
40
+ ]
41
+ ],
42
+ "time": 1727765390.5829365,
43
+ "config_dataset": {
44
+ "generation_kwargs": {
45
+ "max_new_tokens": 100,
46
+ "repetition_penalty": 1.2
47
+ },
48
+ "model_id": "meta-llama/Llama-3.2-1B",
49
+ "flair_model_name": "flair/ner-english-ontonotes-large"
50
+ }
51
+ }
data/meta-llama/Llama-3.2-1B/STOKE_100/config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "generation_kwargs": {
3
+ "max_new_tokens": 100,
4
+ "repetition_penalty": 1.2
5
+ },
6
+ "model_id": "meta-llama/Llama-3.2-1B",
7
+ "flair_model_name": "flair/ner-english-ontonotes-large"
8
+ }
data/meta-llama/Llama-3.2-1B/STOKE_100/stoke_config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "default": {
3
+ "classifier_token": "data/meta-llama/Llama-3.2-1B/STOKE_100/checkpoints/token_classifier/pbK46jjAVx",
4
+ "classifier_span": "data/meta-llama/Llama-3.2-1B/STOKE_100/checkpoints/span_classifier/Rxi8b70XJA"
5
+ },
6
+ "basic": {
7
+ "classifier_token": "data/meta-llama/Llama-3.2-1B/STOKE_100/checkpoints/token_classifier/dR8xQB4ODU",
8
+ "classifier_span": "data/meta-llama/Llama-3.2-1B/STOKE_100/checkpoints/span_classifier/Rxi8b70XJA"
9
+ }
10
+ }
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ git+https://github.com/nicpopovic/transformers.git@4.45-STOKE
2
+ torch
3
+ matplotlib
4
+ flair
5
+ nltk
6
+ datasets
7
+ torcheval
8
+ gradio