Gabriele Tuccio committed on
Commit
210f2f4
·
1 Parent(s): bd1ac35
Files changed (2)
  1. .gitignore +4 -0
  2. app.py +459 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
+ temp/
+ *.zip
+ *.csv
+ /.gradio/
app.py ADDED
@@ -0,0 +1,459 @@
+ from grammarllm.scripts.grammar_generation import generate_non_terminals, generate_grammar
+ from grammarllm.scripts.map_terminal_tokens import generate_token_maps
+ from grammarllm.scripts.table_parsing import parsing_table
+
+ from grammarllm.modules.BaseStreamer import BaseStreamer
+ from grammarllm.modules.PushdownAutomaton import PushdownAutomaton
+ from grammarllm.modules.SimpleLogitProcessor import MaskLogitsProcessor
+
+ import logging
+ import re
+ import os
+ from collections import defaultdict
+ from tqdm import tqdm
+
+ from grammarllm.utils.common_regex import regex_dict
+ from grammarllm.utils.examples import *
+
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ import gradio as gr
+ import json
+ import zipfile
+
+ def pipeline(words, tokenizer, lhs, count=0, non_terminals=None, FINAL_RULES=None):  # this is more of a preprocessing step for each production in the rules
+     """
+     Process input words to generate context-free grammar rules.
+
+     This function implements a pipeline for creating grammar rules from a set of words
+     or phrases. It processes the input through several stages: tokenization, state
+     transition building, prefix grouping, non-terminal generation, and grammar rule creation.
+     The generated rules are added to a master set of rules.
+
+     Args:
+         words (list): Collection of words or phrases to process.
+         tokenizer: Tokenizer object used to convert words into tokens.
+         lhs (str): Left-hand side symbol for grammar rules.
+         count (int, optional): Counter for unique non-terminal generation. Defaults to 0; used to handle prime marks in non-terminal names.
+         non_terminals (list, optional): Predefined non-terminals to use.
+         FINAL_RULES (dict, optional): Existing grammar rules to extend.
+
+     Returns:
+         tuple: A tuple containing:
+             - FINAL_RULES (dict): Updated dictionary of grammar rules.
+             - count (int): Updated counter value for non-terminal generation.
+
+     Dependencies:
+         - build_SState: Creates state transitions from input words
+         - group_by_prefix: Groups transitions by their prefixes
+         - generate_non_terminals: Creates non-terminal symbols
+         - generate_grammar: Generates grammar rules
+     """
+     def build_SState(classes, tokenizer):
+         SState = []
+         tokenized_classes = [tokenizer.tokenize(c) for c in classes]
+
+         glob_count = 1
+         pbar = tqdm(total=len(classes), desc="Build state")
+
+         for tok_class in tokenized_classes:
+             state = 0
+             for token in tok_class:
+                 if token not in SState:  # TODO: try removing this check if it turns out to be unnecessary
+                     SState.append((state, token, glob_count))
+                     glob_count += 1
+                     state += 1
+             pbar.update(1)
+
+         pbar.close()
+         logging.info(SState)
+         # print(list(SState))
+         return SState
+
+     def group_by_prefix(transitions):
+         """Group transitions by their state and prefix"""
+         grammar = defaultdict(list)
+
+         # Build transition map
+         for state, symbol, end in transitions:
+             grammar[state].append((symbol, end))
+
+         # Group by state and prefix
+         grouped = defaultdict(lambda: defaultdict(list))
+         for state, transitions_list in grammar.items():
+             for symbol, end in transitions_list:
+                 grouped[state][symbol].append((symbol, end))
+
+         return grouped
+
+     transitions = build_SState(words, tokenizer)
+     grouped_data = group_by_prefix(transitions)
+
+     # Generate non-terminals
+     G, S = generate_non_terminals(grouped_data, count=count)
+     count += 1  # incremented to handle the prime marks added to non-terminal names
+
+     # tokenizer.eos_token
+     grammar_rules = generate_grammar(G, S, NT=lhs, eos_symbol='|eot|', non_terminals_list=non_terminals)
+
+
+     for key, values in grammar_rules.items():
+         if key in FINAL_RULES:
+             FINAL_RULES[key].extend(values)
+         else:
+             FINAL_RULES[key] = values
+
+
+     logging.info("\nGrouped Data:")
+     for state, prefixes in grouped_data.items():
+         logging.info(f"State {state}:")
+         for prefix, class_labels_list in prefixes.items():
+             logging.info(f" {prefix} -> {class_labels_list}")
+
+     logging.info("\n Generated Non-Terminals:\n")
+     for nt, prefix in G.items():
+         logging.info(f"{nt} -> {prefix}")
+
+     logging.info("\n Ends Non-Terminals:\n")
+     for nt, prefix in S.items():
+         logging.info(f"{nt} -> {prefix}")
+
+
+     logging.info("\nGrammar Rules:\n")
+     for nt, rules in grammar_rules.items():
+         for rule in rules:
+             logging.info(f"{rule}")
+
+     return FINAL_RULES, count
+
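+ # Illustrative sketch (comments only, never executed): process_grammar_rules below calls
+ # this pipeline once per production whose alternatives carry <<...>> tags, roughly as
+ #     rules, count = pipeline(["positive", "negative"], tokenizer, lhs="S*",
+ #                             count=0, non_terminals=["A", "B"], FINAL_RULES={})
+ # The argument values are assumptions chosen for illustration; the exact rules produced
+ # depend on how the tokenizer segments each tag word.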
+ def process_grammar_rules(productions, tokenizer):  # arguably more of a pipeline that builds final_rules, since it calls pipeline()
+
+     """
+     Process grammar production rules.
+
+     This function iterates through the production rules and handles each alternative
+     according to whether it contains <<...>> tags: untagged alternatives are assigned
+     directly to the final rules, while tagged alternatives are expanded through pipeline().
+
+     Args:
+         productions (dict): Dictionary of grammar production rules
+         tokenizer: Tokenizer to use for processing
+
+     Returns:
+         tuple: (final_rules dict, updated non-terminal counter)
+     """
+     def extract_tags_and_others(rhs_list):
+
+         tags_list = []
+         others_list = []
+         tag_pattern = re.compile(r'<<(.+?)>>')
+
+         def smart_split(item):
+             # Find every <<...>> tag and separate the surrounding text
+             matches = list(tag_pattern.finditer(item))
+             parts = []
+             last_index = 0
+
+             for match in matches:
+                 # Add the text before the tag, split into words
+                 pre_text = item[last_index:match.start()]
+                 parts.extend(pre_text.strip().split())
+
+                 # Add the whole tag as a single unit
+                 parts.append(match.group(0))
+                 last_index = match.end()
+
+             # Add any text after the last tag
+             post_text = item[last_index:]
+             parts.extend(post_text.strip().split())
+
+             return parts
+
+         for item in rhs_list:
+             tags = []
+             others = []
+             if re.search(tag_pattern, item):
+                 words = smart_split(item)
+                 current_chunk = []
+                 for word in words:
+                     match = re.fullmatch(tag_pattern, word)
+                     if match:
+                         tags.append(match.group(1))  # keep only the tag content
+                     else:
+                         current_chunk.append(word)
+
+                 if current_chunk:
+                     others.append(' '.join(current_chunk))
+                 else:
+                     others.append(None)
+
+                 tags_list.append(tags)
+                 others_list.append(others)
+             else:
+                 tags_list.append([None])
+                 others_list.append([item])
+
+         return tags_list, others_list
+
+     final_rules = {}
+     count = 0
+
+     for lhs, rhs_list in productions.items():
+
+
+         tags_list, non_terminals_list = extract_tags_and_others(rhs_list)
+
+         filtered_tags = []
+         filtered_non_terminals = []
+         for j in range(len(tags_list)):
+             tag_group = tags_list[j]
+             non_terminal_group = non_terminals_list[j]
+
+
+             if any(tag is not None for tag in tag_group):
+                 # Filter out None tags and add them directly to final_rules
+                 i = 0
+
+                 while i < len(tag_group):
+                     if tag_group[i] is None:
+                         # Add rule directly to final_rules
+                         if lhs in final_rules:
+                             final_rules[lhs].append(rhs_list[i])
+                         else:
+                             final_rules[lhs] = [rhs_list[i]]
+
+                         # Remove processed tag and non-terminal
+                         tag_group.pop(i)
+                         non_terminal_group.pop(i)
+                     else:
+                         # Keep tag and non-terminal for further processing
+                         filtered_tags.append(tag_group[i])
+                         if i < len(non_terminal_group):
+                             filtered_non_terminals.append(non_terminal_group[i])
+                         i += 1
+             else:
+                 # All tags are None, add rules directly
+                 final_rules.update({lhs: rhs_list})
+
+         # print(f"Filtered tags: {filtered_tags}")  # DEBUG
+         # print(f"Filtered non-terminals: {filtered_non_terminals}")  # DEBUG
+
+         # Process remaining tags through the general pipeline
+         if filtered_tags:
+             final_rules, count = pipeline(
+                 filtered_tags, tokenizer, lhs,
+                 count=count,
+                 non_terminals=filtered_non_terminals,
+                 FINAL_RULES=final_rules
+             )
+
+
+     return final_rules, count
+
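+ # Worked example (based on the default grammar defined further below): for the production
+ #     "S*": ["<<positive>> A", "<<negative>> B", "<<neutral>> C"]
+ # extract_tags_and_others yields the tags ["positive", "negative", "neutral"] and the
+ # trailing non-terminals ["A", "B", "C"]; pipeline() then expands each tag token by token
+ # into fresh non-terminals, so decoding must spell out one tag word before continuing
+ # with A, B, or C.
+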
+ def get_parsing_table_and_map_tt(tokenizer, productions=None, regex_dict=None):
+     def write_grammar_to_file(grammar_rules):
+         """Write grammar rules to a file"""
+         output_file = os.path.join('temp', 'grammar_rules.txt')
+         os.makedirs(os.path.dirname(output_file), exist_ok=True)
+         with open(output_file, 'w') as f:
+             for non_terminal, rules in grammar_rules.items():
+                 for rule in rules:
+                     f.write(f"{non_terminal} -> {rule}\n")
+                 f.write("\n")
+         logging.info(f"\nGrammar Rules written to {output_file}")
+
+     # Get final grammar rules
+     final_rules, _ = process_grammar_rules(productions, tokenizer)
+
+     # print(final_rules)  # DEBUG
+     write_grammar_to_file(final_rules)
+     logging.info(final_rules)
+
+     # Generate parsing table
+     pars_tab = parsing_table(final_rules)
+
+     # Generate token maps
+     if regex_dict:
+         map_terminal_tokens = generate_token_maps(tokenizer, pars_tab, regex_dict)
+     else:
+         map_terminal_tokens = generate_token_maps(tokenizer, pars_tab)
+
+     logging.info("\nMap Terminal Tokens:\n")
+     for key, values in map_terminal_tokens.items():
+         logging.info(f"{key} -> {values}")
+
+     return pars_tab, map_terminal_tokens
+
+ def generate_grammar_parameters(tokenizer, pars_tab, map_terminal_tokens):
+     # Create Pushdown Automaton and initialize processors and streamer
+     pda = PushdownAutomaton(grammar=pars_tab, startSymbol='S*', map=map_terminal_tokens)
+     return MaskLogitsProcessor(tokenizer, pda), BaseStreamer(tokenizer, pda)
+
+ def setup_logging():
+     """Setup logging configuration."""
+     log_dir = 'temp'
+     os.makedirs(log_dir, exist_ok=True)  # Ensure the log directory exists
+
+     logging.basicConfig(
+         filename=os.path.join(log_dir, 'GRAM-GEN.log'),
+         level=logging.INFO,
+         format='%(asctime)s - %(levelname)s - %(message)s',
+         filemode='w+'  # Overwrites the file every time
+     )
+
+ def generate_text(model, tokenizer, text, logit_processor, streamer, max_new_tokens=400, do_sample=False, temperature=None, top_p=None, **kwargs):
+     """
+     Generate grammar-constrained text, with safe configuration of the generation parameters.
+
+     Args:
+         model: The pre-trained model.
+         tokenizer: The model's tokenizer.
+         text: Initial input text.
+         logit_processor: Grammar-based logits processor.
+         streamer: Streamer for live output.
+         max_new_tokens: Maximum number of new tokens to generate.
+         do_sample: If True, enables stochastic generation.
+         temperature: Controls randomness (used only if do_sample=True).
+         top_p: Top-p (nucleus sampling), used only if do_sample=True.
+         **kwargs: Optional extra parameters for model.generate().
+     """
+
+     try:
+         tokenized_input = tokenizer(text, return_tensors="pt")
+
+         # Safe defaults
+         kwargs.setdefault("num_beams", 1)  # beam search disabled
+         kwargs.setdefault("pad_token_id", tokenizer.eos_token_id)
+
+         # num_beams safety check
+         if kwargs["num_beams"] != 1:
+             logging.warning("⚠️ num_beams > 1 is not compatible with grammar-constrained generation. Automatically setting num_beams=1.")
+             kwargs["num_beams"] = 1
+
+         # Sampling parameters
+         if do_sample:
+             if temperature is not None:
+                 kwargs["temperature"] = temperature
+             if top_p is not None:
+                 kwargs["top_p"] = top_p
+         else:
+             # Remove sampling parameters if present
+             kwargs.pop("temperature", None)
+             kwargs.pop("top_p", None)
+
+         # Device compatibility
+         device = model.device
+         input_ids = tokenized_input["input_ids"].to(device)
+         if input_ids.device != model.device:
+             logging.warning(f"'input_ids' are on device {input_ids.device} while the model is on device {model.device}. Moving 'input_ids' to the model's device.")
+
+         attention_mask = tokenized_input["attention_mask"].to(device)
+         if attention_mask.device != model.device:
+             logging.warning(f"'attention_mask' is on device {attention_mask.device} while the model is on device {model.device}. Moving 'attention_mask' to the model's device.")
+
+
+         start = input_ids.shape[1]
+
+         output = model.generate(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             do_sample=do_sample,
+             max_new_tokens=max_new_tokens,
+             streamer=streamer,
+             logits_processor=[logit_processor],
+             **kwargs
+         )
+
+         answer = tokenizer.decode(output[0][start:], skip_special_tokens=True)
+
+         return answer
+
+     except Exception as e:
+         raise RuntimeError(f"Error during text generation: {e}")
+
+
+
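+ # Illustrative sketch (comments only, never executed): greedy decoding is the default,
+ # and temperature/top_p are honoured only when do_sample=True. The variable names mirror
+ # run_grammarllm() below and are assumptions made for illustration.
+ #     answer = generate_text(model, tokenizer, "The sentiment is", LogitProcessor, Streamer,
+ #                            max_new_tokens=50, do_sample=True, temperature=0.7, top_p=0.9)
+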
+ def run_grammarllm(prompt, productions_json, regex_json):
+     setup_logging()
+
+     # Parse productions
+     try:
+         productions = json.loads(productions_json)
+     except json.JSONDecodeError:
+         return "Error: invalid productions JSON.", None
+
+     # Parse regex_dict
+     try:
+         regex_raw = json.loads(regex_json)
+         regex_dict = {key: re.compile(pattern) for key, pattern in regex_raw.items()}
+     except json.JSONDecodeError:
+         return "Error: invalid regex JSON.", None
+     except re.error as e:
+         return f"Error while compiling regex: {str(e)}", None
+
+     try:
+         tokenizer = AutoTokenizer.from_pretrained("gpt2")
+         model = AutoModelForCausalLM.from_pretrained("gpt2")
+
+         pars_table, map_terminal_tokens = get_parsing_table_and_map_tt(
+             tokenizer,
+             productions=productions,
+             regex_dict=regex_dict,
+         )
+
+         LogitProcessor, Streamer = generate_grammar_parameters(tokenizer, pars_table, map_terminal_tokens)
+         output = generate_text(model, tokenizer, prompt, LogitProcessor, Streamer)
+
+         temp_dir = "./temp"
+         zip_path = temp_dir + ".zip"
+
+         with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
+             for root, dirs, files in os.walk(temp_dir):
+                 for file in files:
+                     file_path = os.path.join(root, file)
+                     arcname = os.path.relpath(file_path, temp_dir)
+                     zipf.write(file_path, arcname)
+
+         return output, zip_path
+
+     except Exception as e:
+         return f"Error during inference: {str(e)}", None
+
+ # Example input for regex_json (JSON string)
+ default_regex_json = json.dumps({
+     "regex_alfanum": "[a-zA-Z0-9]+",
+     "regex_letters": "[a-zA-Z]+",
+     "regex_number": "\\d+",
+     "regex_decimal": "\\d+([.,]\\d+)?",
+     "regex_var": "[a-zA-Z_][a-zA-Z0-9_]*",
+     "regex_)": "\\)",
+     "regex_(": "\\("
+ }, indent=4)
+
+ default_grammar_json = json.dumps({
+     "S*": ["<<positive>> A", "<<negative>> B", "<<neutral>> C"],
+     "A": ["<<happy>> D", "<<peaceful>> E", "<<joyful>> F"],
+     "B": ["<<sad>>", "<<angry>>", "<<frustrated>>"],
+     "C": ["<<calm>>", "<<indifferent>>", "<<unemotional>>"],
+     "D": ["<<enthusiastic>>"],
+     "E": ["<<content>>"],
+     "F": ["<<excited>>"]
+ }, indent=4)
+
+
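+ # For illustration (an assumption, not a verified output): with this default grammar,
+ # constrained decoding should only be able to emit strings derivable from S*, such as
+ # "positive happy enthusiastic" or "negative sad" (modulo tokenizer-dependent spacing).
+ # The regex_dict above defines pattern-based terminals (numbers, identifiers, parentheses)
+ # that generate_token_maps can use; this particular grammar only relies on literal <<...>> tags.
+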
+ demo = gr.Interface(
+     fn=run_grammarllm,
+     inputs=[
+         gr.Textbox(label="Enter a text prompt"),
+         gr.Textbox(label="Enter productions (JSON)", lines=10, value=default_grammar_json),
+         gr.Textbox(label="Enter regex_dict (JSON)", lines=10, value=default_regex_json),
+     ],
+     outputs=[
+         gr.Textbox(label="Generated output"),
+         gr.File(label="Download ZIP"),
+     ],
+     title="GrammarLLM with output and ZIP download",
+     description="Enter a prompt, productions, and regexes to generate text and download the files.",
+ )
+
+ if __name__ == "__main__":
+     demo.launch(debug=True)