nicpopovic commited on
Commit
a1c9350
·
verified ·
1 Parent(s): b244637

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +495 -0
app.py ADDED
@@ -0,0 +1,495 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, STOKEStreamer
3
+ from threading import Thread
4
+ import json
5
+ import torch
6
+ import os
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ from matplotlib.colors import to_hex
10
+ from bs4 import BeautifulSoup
11
+
12
+ def clean_html(html_content):
13
+ # Parse the HTML
14
+ soup = BeautifulSoup(html_content, 'html.parser')
15
+
16
+ # Remove all elements with class 'small-text'
17
+ for element in soup.find_all(class_='small-text'):
18
+ element.decompose() # Removes the element from the tree
19
+
20
+ # Get the plain text, stripping any remaining HTML tags
21
+ cleaned_text = soup.get_text()
22
+
23
+ return cleaned_text.strip().replace(" ", " ").replace("( ", "(").replace(" )", ")")
24
+
25
+ # Reusing the original MLP class and other functions (unchanged) except those specific to Streamlit
26
+ class MLP(torch.nn.Module):
27
+ def __init__(self, input_dim, output_dim, hidden_dim=1024, layer_id=0, cuda=False):
28
+ super(MLP, self).__init__()
29
+ self.fc1 = torch.nn.Linear(input_dim, hidden_dim)
30
+ self.fc3 = torch.nn.Linear(hidden_dim, output_dim)
31
+ self.layer_id = layer_id
32
+ if cuda:
33
+ self.device = "cuda"
34
+ else:
35
+ self.device = "cpu"
36
+ self.to(self.device)
37
+
38
+ def forward(self, x):
39
+ x = torch.flatten(x, start_dim=1)
40
+ x = torch.relu(self.fc1(x))
41
+ x = self.fc3(x)
42
+ return torch.argmax(x, dim=-1).cpu().detach(), torch.softmax(x, dim=-1).cpu().detach()
43
+
44
+ def map_value_to_color(value, colormap_name='tab20c'):
45
+ value = np.clip(value, 0.0, 1.0)
46
+ colormap = plt.get_cmap(colormap_name)
47
+ rgba_color = colormap(value)
48
+ css_color = to_hex(rgba_color)
49
+ return css_color + "88"
50
+
51
+ # Caching functions for model and classifier
52
+ model_cache = {}
53
+
54
+ def get_model_and_tokenizer(name):
55
+ if name not in model_cache:
56
+ tok = AutoTokenizer.from_pretrained(name, token=os.getenv('HF_TOKEN'))
57
+ model = AutoModelForCausalLM.from_pretrained(name, token=os.getenv('HF_TOKEN'), torch_dtype="bfloat16")
58
+ #model = AutoModelForCausalLM.from_pretrained(name, token=, load_in_4bit=True)
59
+ model_cache[name] = (model, tok)
60
+ return model_cache[name]
61
+
62
+ def get_classifiers_for_model(att_size, emb_size, device, config_paths):
63
+ config = {
64
+ "classifier_token": json.load(open(os.path.join(config_paths["classifier_token"], "config.json"), "r")),
65
+ "classifier_span": json.load(open(os.path.join(config_paths["classifier_span"], "config.json"), "r"))
66
+ }
67
+ layer_id = config["classifier_token"]["layer"]
68
+
69
+ classifier_span = MLP(att_size, 2, hidden_dim=config["classifier_span"]["classifier_dim"]).to(device)
70
+ classifier_span.load_state_dict(torch.load(os.path.join(config_paths["classifier_span"], "checkpoint.pt"), map_location=device))
71
+
72
+ classifier_token = MLP(emb_size, len(config["classifier_token"]["label_map"]), layer_id=layer_id, hidden_dim=config["classifier_token"]["classifier_dim"]).to(device)
73
+ classifier_token.load_state_dict(torch.load(os.path.join(config_paths["classifier_token"], "checkpoint.pt"), map_location=device))
74
+
75
+ return classifier_span, classifier_token, config["classifier_token"]["label_map"]
76
+
77
+ def find_datasets_and_model_ids(root_dir):
78
+ datasets = {}
79
+ for root, dirs, files in os.walk(root_dir):
80
+ if 'config.json' in files and 'stoke_config.json' in files:
81
+ config_path = os.path.join(root, 'config.json')
82
+ stoke_config_path = os.path.join(root, 'stoke_config.json')
83
+
84
+ with open(config_path, 'r') as f:
85
+ config_data = json.load(f)
86
+ model_id = config_data.get('model_id')
87
+ if model_id:
88
+ dataset_name = os.path.basename(os.path.dirname(config_path))
89
+
90
+ with open(stoke_config_path, 'r') as f:
91
+ stoke_config_data = json.load(f)
92
+ if model_id:
93
+ dataset_name = os.path.basename(os.path.dirname(stoke_config_path))
94
+ datasets.setdefault(model_id, {})[dataset_name] = stoke_config_data
95
+ return datasets
96
+
97
+ def filter_spans(spans_and_values):
98
+ if spans_and_values == []:
99
+ return [], []
100
+ # Create a dictionary to store spans based on their second index values
101
+ span_dict = {}
102
+
103
+ spans, values = [x[0] for x in spans_and_values], [x[1] for x in spans_and_values]
104
+
105
+ # Iterate through the spans and update the dictionary with the highest value
106
+ for span, value in zip(spans, values):
107
+ start, end = span
108
+ if start > end or end - start > 15 or start == 0:
109
+ continue
110
+ current_value = span_dict.get(end, None)
111
+
112
+ if current_value is None or current_value[1] < value:
113
+ span_dict[end] = (span, value)
114
+
115
+ if span_dict == {}:
116
+ return [], []
117
+ # Extract the filtered spans and values
118
+ filtered_spans, filtered_values = zip(*span_dict.values())
119
+
120
+ return list(filtered_spans), list(filtered_values)
121
+
122
+ def remove_overlapping_spans(spans):
123
+ # Sort the spans based on their end points
124
+ sorted_spans = sorted(spans, key=lambda x: x[0][1])
125
+
126
+ non_overlapping_spans = []
127
+ last_end = float('-inf')
128
+
129
+ # Iterate through the sorted spans
130
+ for span in sorted_spans:
131
+ start, end = span[0]
132
+ value = span[1]
133
+
134
+ # If the current span does not overlap with the previous one
135
+ if start >= last_end:
136
+ non_overlapping_spans.append(span)
137
+ last_end = end
138
+ else:
139
+ # If it overlaps, choose the one with the highest value
140
+ existing_span_index = -1
141
+ for i, existing_span in enumerate(non_overlapping_spans):
142
+ if existing_span[0][1] <= start:
143
+ existing_span_index = i
144
+ break
145
+ if existing_span_index != -1 and non_overlapping_spans[existing_span_index][1] < value:
146
+ non_overlapping_spans[existing_span_index] = span
147
+
148
+ return non_overlapping_spans
149
+
150
+ def generate_html_no_overlap(tokenized_text, spans):
151
+ current_index = 0
152
+ html_content = ""
153
+
154
+ for (span_start, span_end), value in spans:
155
+ # Add text before the span
156
+ html_content += "".join(tokenized_text[current_index:span_start])
157
+
158
+ # Add the span with underlining
159
+ html_content += "<b><u>"
160
+ html_content += "".join(tokenized_text[span_start:span_end])
161
+ html_content += "</u></b> "
162
+
163
+ current_index = span_end
164
+
165
+ # Add any remaining text after the last span
166
+ html_content += "".join(tokenized_text[current_index:])
167
+
168
+ return html_content
169
+
170
+
171
+ def generate_html_spanwise(token_strings, tokenwise_preds, spans, tokenizer, new_tags):
172
+
173
+ # spanwise annotated text
174
+ annotated = []
175
+ span_ends = -1
176
+ in_span = False
177
+
178
+ out_of_span_tokens = []
179
+ for i in reversed(range(len(tokenwise_preds))):
180
+
181
+ if in_span:
182
+ if i >= span_ends:
183
+ continue
184
+ else:
185
+ in_span = False
186
+
187
+ predicted_class = ""
188
+ style = ""
189
+
190
+ span = None
191
+ for s in spans:
192
+ if s[1] == i+1:
193
+ span = s
194
+
195
+ if tokenwise_preds[i] != 0 and span is not None:
196
+ predicted_class = f"highlight spanhighlight"
197
+ style = f"background-color: {map_value_to_color((tokenwise_preds[i]-1)/(len(new_tags)-1))}"
198
+ if tokenizer.convert_tokens_to_string([token_strings[i]]).startswith(" "):
199
+ annotated.append("Ġ")
200
+
201
+ span_opener = f"Ġ<span class='{predicted_class}' data-tooltip-text='{new_tags[tokenwise_preds[i]]}' style='{style}'>".replace(" ", "Ġ")
202
+ span_end = f"<span class='small-text'>{new_tags[tokenwise_preds[i]]}</span></span>"
203
+ annotated.extend(out_of_span_tokens)
204
+ out_of_span_tokens = []
205
+ span_ends = span[0]
206
+ in_span = True
207
+ annotated.append(span_end)
208
+ annotated.extend([token_strings[x] for x in reversed(range(span[0], span[1]))])
209
+ annotated.append(span_opener)
210
+ else:
211
+ out_of_span_tokens.append(token_strings[i])
212
+
213
+ annotated.extend(out_of_span_tokens)
214
+
215
+ return [x for x in reversed(annotated)]
216
+
217
+ def gen_json(input_text, max_new_tokens):
218
+ streamer = STOKEStreamer(tok, classifier_token, classifier_span)
219
+
220
+ new_tags = label_map
221
+
222
+ inputs = tok([f" {input_text}"], return_tensors="pt").to(model.device)
223
+ generation_kwargs = dict(
224
+ inputs, streamer=streamer, max_new_tokens=max_new_tokens,
225
+ repetition_penalty=1.2, do_sample=False
226
+ )
227
+
228
+ def generate_async():
229
+ model.generate(**generation_kwargs)
230
+
231
+ thread = Thread(target=generate_async)
232
+ thread.start()
233
+
234
+ # Display generated text as it becomes available
235
+ output_text = ""
236
+ text_tokenwise = ""
237
+ text_spans = ""
238
+ removed_spans = ""
239
+ tags = []
240
+ spans = []
241
+ for new_text in streamer:
242
+ if new_text[1] is not None and new_text[2] != ['']:
243
+ text_tokenwise = ""
244
+ output_text = ""
245
+ tags.extend(new_text[1])
246
+ spans.extend(new_text[-1])
247
+
248
+ # Tokenwise Classification
249
+ for tk, pred in zip(new_text[2],tags):
250
+ if pred != 0:
251
+ style = f"background-color: {map_value_to_color((pred-1)/(len(new_tags)-1))}"
252
+ if tk.startswith(" "):
253
+ text_tokenwise += " "
254
+ text_tokenwise += f"<span class='tooltip highlight' data-tooltip-text='{new_tags[pred]}' style='{style}'>{tk}</span>"
255
+ output_text += tk
256
+ else:
257
+ text_tokenwise += tk
258
+ output_text += tk
259
+
260
+ # Span Classification
261
+ text_spans = ""
262
+ if len(spans) > 0:
263
+ filtered_spans = remove_overlapping_spans(spans)
264
+ text_spans = generate_html_no_overlap(new_text[2], filtered_spans)
265
+ if len(spans) - len(filtered_spans) > 0:
266
+ removed_spans = f"{len(spans) - len(filtered_spans)} span(s) hidden due to overlap."
267
+ else:
268
+ for tk in new_text[2]:
269
+ text_spans += f"{tk}"
270
+
271
+ # Spanwise Classification
272
+ annotated_tokens = generate_html_spanwise(new_text[2], tags, [x for x in filter_spans(spans)[0]], tok, new_tags)
273
+ generated_text_spanwise = tok.convert_tokens_to_string(annotated_tokens).replace("<|endoftext|>", "").replace("<|begin_of_text|>", "")
274
+
275
+ output = f"{css}<br>"
276
+ output += generated_text_spanwise.replace("\n", " ").replace("$", "$") + "\n<br>"
277
+ #output += "<h5>Show tokenwise classification</h5>\n" + text_tokenwise.replace("\n", " ").replace("$", "\\$").replace("<|endoftext|>", "").replace("<|begin_of_text|>", "")
278
+ #output += "</details><details><summary>Show spans</summary>\n" + text_spans.replace("\n", " ").replace("$", "\\$")
279
+ #if removed_spans != "":
280
+ # output += f"<br><br><i>({removed_spans})</i>"
281
+ list_of_spans = [{"name": tok.convert_tokens_to_string(new_text[2][x[0]:x[1]]).strip(), "type": new_tags[tags[x[1]-1]]} for x in filter_spans(spans)[0] if new_tags[tags[x[1]-1]] != "O"]
282
+
283
+ out_dict = {"text": output_text.replace("<|endoftext|>", "").replace("<|begin_of_text|>", "".strip()), "entites": list_of_spans}
284
+
285
+ yield out_dict
286
+ return
287
+
288
+ # Creating the Gradio Interface
289
+ def generate_text(input_text, messages=None):
290
+ if input_text == "":
291
+ yield "Please enter some text first."
292
+ return
293
+
294
+ token_limit=250
295
+ #print([clean_html(x["content"]) for x in messages])
296
+
297
+ streamer = STOKEStreamer(tok, classifier_token, classifier_span)
298
+
299
+ new_tags = label_map
300
+
301
+ if messages is None:
302
+ messages = []
303
+ else:
304
+ messages = []
305
+ system="""You are a knowledge assistant. Keep your responses very short."""
306
+ messages = [{"role": "system", "content": system}]+ [{"role": x["role"], "content": clean_html(x["content"])} for x in messages] +[{"role": "user", "content": input_text}]
307
+ input_text = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
308
+ inputs = tok([input_text], return_tensors="pt").to(model.device)
309
+
310
+ if len(inputs.input_ids[0]) > 80:
311
+ yield [{"role": "assistant", "content": "Your message is too long for this demo, sorry :("}]
312
+ return
313
+
314
+ #inputs = tok([f" {input_text[:200]}"], return_tensors="pt").to(model.device)
315
+ #inputs = tok([input_text[:200]], return_tensors="pt").to(model.device)
316
+ generation_kwargs = dict(
317
+ inputs, streamer=streamer, max_new_tokens=token_limit-len(inputs.input_ids[0]),
318
+ repetition_penalty=1.2, do_sample=False
319
+ )
320
+
321
+ def generate_async():
322
+ model.generate(**generation_kwargs)
323
+
324
+ thread = Thread(target=generate_async)
325
+ thread.start()
326
+
327
+ # Display generated text as it becomes available
328
+ output_text = ""
329
+ text_tokenwise = ""
330
+ text_spans = ""
331
+ removed_spans = ""
332
+ tags = []
333
+ spans = []
334
+ for new_text in streamer:
335
+ if new_text[1] is not None and new_text[2] != ['']:
336
+ text_tokenwise = ""
337
+ output_text = ""
338
+ tags.extend(new_text[1])
339
+ spans.extend(new_text[-1])
340
+
341
+ # Tokenwise Classification
342
+ for tk, pred in zip(new_text[2],tags):
343
+ if pred != 0:
344
+ style = f"background-color: {map_value_to_color((pred-1)/(len(new_tags)-1))}"
345
+ if tk.startswith(" "):
346
+ text_tokenwise += " "
347
+ text_tokenwise += f"<span class='tooltip highlight' data-tooltip-text='{new_tags[pred]}' style='{style}'>{tk}</span>"
348
+ output_text += tk
349
+ else:
350
+ text_tokenwise += tk
351
+ output_text += tk
352
+
353
+ # Span Classification
354
+ text_spans = ""
355
+ if len(spans) > 0:
356
+ filtered_spans = remove_overlapping_spans(spans)
357
+ text_spans = generate_html_no_overlap(new_text[2], filtered_spans)
358
+ if len(spans) - len(filtered_spans) > 0:
359
+ removed_spans = f"{len(spans) - len(filtered_spans)} span(s) hidden due to overlap."
360
+ else:
361
+ for tk in new_text[2]:
362
+ text_spans += f"{tk}"
363
+
364
+ # Spanwise Classification
365
+ annotated_tokens = generate_html_spanwise(new_text[2], tags, [x for x in filter_spans(spans)[0]], tok, new_tags)
366
+ generated_text_spanwise = tok.convert_tokens_to_string(annotated_tokens).replace("<|endoftext|>", "").replace("<|begin_of_text|>", "")
367
+
368
+ output = generated_text_spanwise
369
+ #output += "<h5>Show tokenwise classification</h5>\n" + text_tokenwise.replace("\n", " ").replace("$", "\\$").replace("<|endoftext|>", "").replace("<|begin_of_text|>", "")
370
+ #output += "</details><details><summary>Show spans</summary>\n" + text_spans.replace("\n", " ").replace("$", "\\$")
371
+ #if removed_spans != "":
372
+ # output += f"<br><br><i>({removed_spans})</i>"
373
+ list_of_spans = [{"name": tok.convert_tokens_to_string(new_text[2][x[0]:x[1]]).strip(), "type": new_tags[tags[x[1]-1]]} for x in filter_spans(spans)[0] if new_tags[tags[x[1]-1]] != "O"]
374
+
375
+ out_dict = {"text": output_text.replace("<|endoftext|>", "").replace("<|begin_of_text|>", "").strip(), "entites": list_of_spans}
376
+
377
+ html_out = output.replace("<|endoftext|>", "").replace("<|begin_of_text|>", "").strip().split("<|end_header_id|>\n\n")[-1].replace("**", "")
378
+
379
+ yield [messages[-1]] + [{"role": "assistant", "content": html_out}]
380
+ return
381
+
382
+ # Load datasets and models for the Gradio app
383
+ datasets = find_datasets_and_model_ids("data/")
384
+ available_models = list(datasets.keys())
385
+ available_datasets = {model: list(datasets[model].keys()) for model in available_models}
386
+ available_configs = {model: {dataset: list(datasets[model][dataset].keys()) for dataset in available_datasets[model]} for model in available_models}
387
+
388
+ def update_datasets(model_name):
389
+ return available_datasets[model_name]
390
+
391
+ def update_configs(model_name, dataset_name):
392
+ return available_configs[model_name][dataset_name]
393
+
394
+ model_id = "meta-llama/Llama-3.2-1B-Instruct"
395
+ data_id = "STOKE_500_wikiqa"
396
+ config_id = "default"
397
+
398
+ #model_id = "gpt2"
399
+ #data_id = "1_NER"
400
+ #config_id = "default"
401
+
402
+ model, tok = get_model_and_tokenizer(model_id)
403
+ if torch.cuda.is_available():
404
+ model.cuda()
405
+
406
+ # Load model classifiers
407
+ try:
408
+ classifier_span, classifier_token, label_map = get_classifiers_for_model(
409
+ model.config.n_head * model.config.n_layer, model.config.n_embd, model.device,
410
+ datasets[model_id][data_id][config_id]
411
+ )
412
+ except:
413
+ classifier_span, classifier_token, label_map = get_classifiers_for_model(
414
+ model.config.num_attention_heads * model.config.num_hidden_layers, model.config.hidden_size, model.device,
415
+ datasets[model_id][data_id][config_id]
416
+ )
417
+
418
+
419
+ css = """
420
+ <style>
421
+ .prose {
422
+ line-height: 200%;
423
+ }
424
+ .highlight {
425
+ display: inline;
426
+ }
427
+ .highlight::after {
428
+ background-color: var(data-color);
429
+ }
430
+ .spanhighlight {
431
+ padding: 2px 5px;
432
+ border-radius: 5px;
433
+ }
434
+ .tooltip {
435
+ position: relative;
436
+ display: inline-block;
437
+ }
438
+
439
+ .tooltip::after {
440
+ content: attr(data-tooltip-text); /* Set content from data-tooltip-text attribute */
441
+ display: none;
442
+ position: absolute;
443
+ background-color: #333;
444
+ color: #fff;
445
+ padding: 5px;
446
+ border-radius: 5px;
447
+ bottom: 100%; /* Position it above the element */
448
+ left: 50%;
449
+ transform: translateX(-50%);
450
+ width: auto;
451
+ min-width: 120px;
452
+ margin: 0 auto;
453
+ text-align: center;
454
+ }
455
+
456
+ .tooltip:hover::after {
457
+ display: block; /* Show the tooltip on hover */
458
+ }
459
+
460
+ .small-text {
461
+ padding: 2px 5px;
462
+ background-color: white;
463
+ border-radius: 5px;
464
+ font-size: xx-small;
465
+ margin-left: 0.5em;
466
+ vertical-align: 0.2em;
467
+ font-weight: bold;
468
+ color: grey!important;
469
+ }
470
+ footer {
471
+ display:none !important
472
+ }
473
+ .gradio-container {
474
+ padding: 0!important;
475
+ height:400px;
476
+ }
477
+ </style>"""
478
+ """
479
+ with gr.Blocks(css=css, elem_id="chatbox") as demo:
480
+ gr.ChatInterface(generate_text, examples=["Who where the Beatles?", "Whats the GDP of Norway?", "List some fun things to do in Miami", "What do you know about the KIT in Karlsruhe?", "Give me a list of the most iconic 90s songs", "Whats the typical cost of a pizza in New York City?", "Got any suggestions for a day trip from Miami?", "Tell me about the climate in Europe.", "Where can I go scuba diving?", "give me a list of famous people and their years of birth"], type="messages")
481
+ """
482
+
483
+ example_messages=[{'role': 'user', 'content': 'Who where the Beatles?'}, {'role': 'assistant', 'content': "The <span class='highlight spanhighlight' data-tooltip-text='ORG' style='background-color: #756bb188'> Beatles<span class='small-text'>ORG</span></span> were a <span class='highlight spanhighlight' data-tooltip-text='NORP' style='background-color: #a1d99b88'> British<span class='small-text'>NORP</span></span> rock band formed in <span class='highlight spanhighlight' data-tooltip-text='GPE' style='background-color: #e6550d88'> Liverpool<span class='small-text'>GPE</span></span>, <span class='highlight spanhighlight' data-tooltip-text='GPE' style='background-color: #e6550d88'> England<span class='small-text'>GPE</span></span> in <span class='highlight spanhighlight' data-tooltip-text='DATE' style='background-color: #6baed688'>1960<span class='small-text'>DATE</span></span> that rose to fame with their music and iconic style during the late <span class='highlight spanhighlight' data-tooltip-text='DATE' style='background-color: #6baed688'>1950s<span class='small-text'>DATE</span></span> and <span class='highlight spanhighlight' data-tooltip-text='DATE' style='background-color: #6baed688'> early 1960s<span class='small-text'>DATE</span></span>. The group consisted of <span class='highlight spanhighlight' data-tooltip-text='PERSON' style='background-color: #bcbddc88'> John Lennon<span class='small-text'>PERSON</span></span> ( <span class='highlight spanhighlight' data-tooltip-text='PERSON' style='background-color: #bcbddc88'>Ringo Starr<span class='small-text'>PERSON</span></span>), <span class='highlight spanhighlight' data-tooltip-text='PERSON' style='background-color: #bcbddc88'> Paul McCartney<span class='small-text'>PERSON</span></span>, <span class='highlight spanhighlight' data-tooltip-text='PERSON' style='background-color: #bcbddc88'> George Harrison<span class='small-text'>PERSON</span></span>, and <span class='highlight spanhighlight' data-tooltip-text='PERSON' style='background-color: #bcbddc88'> Ringo Starr<span class='small-text'>PERSON</span></span>. They're widely regarded as one of the most influential and successful bands in popular culture history."}]
484
+
485
+ with gr.Blocks(css=css, fill_width=True) as demo:
486
+ chatbot = gr.Chatbot(type="messages", value=example_messages)
487
+ msg = gr.Textbox(submit_btn=True)
488
+ msg.submit(lambda: None, None, chatbot).then(generate_text, msg, chatbot, queue="queue")
489
+ # Add an examples section for users to pick from predefined messages
490
+ examples = gr.Examples(examples=["Who where the Beatles?", "Whats the GDP of Norway?", "List some fun things to do in Miami", "What do you know about the KIT in Karlsruhe?"], inputs=msg, run_on_click=True, fn=generate_text, outputs=chatbot)
491
+
492
+
493
+
494
+ demo.launch(server_name="0.0.0.0", server_port=7861)
495
+