from grammarllm.scripts.grammar_generation import generate_non_terminals, generate_grammar
from grammarllm.scripts.map_terminal_tokens import generate_token_maps
from grammarllm.scripts.table_parsing import parsing_table
from grammarllm.modules.BaseStreamer import BaseStreamer
from grammarllm.modules.PushdownAutomaton import PushdownAutomaton
from grammarllm.modules.SimpleLogitProcessor import MaskLogitsProcessor
import logging
import re
import os
from collections import defaultdict
from tqdm import tqdm
from grammarllm.utils.common_regex import regex_dict
from grammarllm.utils.examples import *
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import json
import zipfile
import spaces
import torch
from huggingface_hub import login

login(token=os.getenv("llama_acces_token"))
def pipeline(words, tokenizer, lhs, count=0, non_terminals=None, FINAL_RULES=None):  # essentially a preprocessing step for each production in the rules
    """
    Process input words to generate context-free grammar rules.

    This function implements a pipeline for creating grammar rules from a set of words
    or phrases. It processes the input through several stages: tokenization, state
    transition building, prefix grouping, non-terminal generation, and grammar rule creation.
    The generated rules are added to a master set of rules.

    Args:
        words (list): Collection of words or phrases to process.
        tokenizer: Tokenizer object used to convert words into tokens.
        lhs (str): Left-hand side symbol for grammar rules.
        count (int, optional): Counter for unique non-terminal generation. Defaults to 0;
            used to handle primes (apostrophes) in non-terminal names.
        non_terminals (list, optional): Predefined non-terminals to use.
        FINAL_RULES (dict, optional): Existing grammar rules to extend.

    Returns:
        tuple: A tuple containing:
            - FINAL_RULES (dict): Updated dictionary of grammar rules.
            - count (int): Updated counter value for non-terminal generation.

    Dependencies:
        - build_SState: Creates state transitions from input words
        - group_by_prefix: Groups transitions by their prefixes
        - generate_non_terminals: Creates non-terminal symbols
        - generate_grammar: Generates grammar rules
    """
    def build_SState(classes, tokenizer):
        SState = []
        tokenized_classes = [tokenizer.tokenize(c) for c in classes]
        glob_count = 1
        pbar = tqdm(total=len(classes), desc="Build state")
        for tok_class in tokenized_classes:
            state = 0
            for token in tok_class:
                if token not in SState:  # TODO: try removing this check if unnecessary (SState holds tuples, so the membership test on a token string is always True)
                    SState.append((state, token, glob_count))
                    glob_count += 1
                    state += 1
            pbar.update(1)
        pbar.close()
        logging.info(SState)
        # print(list(SState))
        return SState
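
    # Example (illustrative; actual tokens depend on the tokenizer in use):
    # if "hi" tokenizes to ["h", "i"], build_SState(["hi"], tokenizer) yields
    # [(0, 'h', 1), (1, 'i', 2)], i.e. (source_state, token, transition_id)
    # triples describing a linear token automaton for each phrase.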
    def group_by_prefix(transitions):
        """Group transitions by their state and prefix."""
        grammar = defaultdict(list)
        # Build transition map
        for state, symbol, end in transitions:
            grammar[state].append((symbol, end))
        # Group by state and prefix
        grouped = defaultdict(lambda: defaultdict(list))
        for state, transitions_list in grammar.items():
            for symbol, end in transitions_list:
                grouped[state][symbol].append((symbol, end))
        return grouped
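
    # Example: group_by_prefix([(0, 'h', 1), (0, 'h', 3), (1, 'i', 2)]) gives
    # {0: {'h': [('h', 1), ('h', 3)]}, 1: {'i': [('i', 2)]}}, so transitions
    # that share a token prefix at the same state can share a non-terminal.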
    transitions = build_SState(words, tokenizer)
    grouped_data = group_by_prefix(transitions)

    # Generate non-terminals
    G, S = generate_non_terminals(grouped_data, count=count)
    count += 1  # incremented to handle the primed non-terminal naming issue
    # tokenizer.eos_token
    grammar_rules = generate_grammar(G, S, NT=lhs, eos_symbol='|eot|', non_terminals_list=non_terminals)

    for key, values in grammar_rules.items():
        if key in FINAL_RULES:
            FINAL_RULES[key].extend(values)
        else:
            FINAL_RULES[key] = values

    logging.info("\nGrouped Data:")
    for state, prefixes in grouped_data.items():
        logging.info(f"State {state}:")
        for prefix, class_labels_list in prefixes.items():
            logging.info(f"  {prefix} -> {class_labels_list}")

    logging.info("\nGenerated Non-Terminals:\n")
    for nt, prefix in G.items():
        logging.info(f"{nt} -> {prefix}")

    logging.info("\nEnds Non-Terminals:\n")
    for nt, prefix in S.items():
        logging.info(f"{nt} -> {prefix}")

    logging.info("\nGrammar Rules:\n")
    for nt, rules in grammar_rules.items():
        for rule in rules:
            logging.info(f"{rule}")

    return FINAL_RULES, count
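
# Example (illustrative; exact rule names depend on generate_non_terminals):
# pipeline(["happy", "joyful"], tokenizer, "A", FINAL_RULES={}) tokenizes both
# phrases, merges shared token prefixes into fresh non-terminals, and extends
# FINAL_RULES with productions under "A" that spell out exactly one of the two
# phrases token by token, terminated by the '|eot|' symbol.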
def process_grammar_rules(productions, tokenizer):  # effectively a pipeline that builds final_rules; it delegates tagged rules to pipeline()
    """
    Process grammar production rules based on the specified task.

    This function iterates through production rules and handles them differently
    based on whether the task is 'Classification'/'VR' or 'General'. Rules whose
    right-hand sides contain <<...>> tags are expanded through pipeline(), while
    rules with None tags are assigned directly.

    Args:
        productions (dict): Dictionary of grammar production rules
        tokenizer: Tokenizer to use for processing

    Returns:
        tuple: (final_rules, count), where final_rules is the dict of final
        grammar rules and count is the non-terminal counter.
    """
    def extract_tags_and_others(rhs_list):
        tags_list = []
        others_list = []
        tag_pattern = re.compile(r'<<(.+?)>>')

        def smart_split(item):
            # Find all <<...>> tags and split the surrounding text
            matches = list(tag_pattern.finditer(item))
            parts = []
            last_index = 0
            for match in matches:
                # Add the text before the tag, split on whitespace
                pre_text = item[last_index:match.start()]
                parts.extend(pre_text.strip().split())
                # Add the whole tag as a single unit
                parts.append(match.group(0))
                last_index = match.end()
            # Add any text after the last tag
            post_text = item[last_index:]
            parts.extend(post_text.strip().split())
            return parts
        for item in rhs_list:
            tags = []
            others = []
            if re.search(tag_pattern, item):
                words = smart_split(item)
                current_chunk = []
                for word in words:
                    match = re.fullmatch(tag_pattern, word)
                    if match:
                        tags.append(match.group(1))  # store only the tag content
                    else:
                        current_chunk.append(word)
                if current_chunk:
                    others.append(' '.join(current_chunk))
                else:
                    others.append(None)
                tags_list.append(tags)
                others_list.append(others)
            else:
                tags_list.append([None])
                others_list.append([item])
        return tags_list, others_list
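
    # Example: extract_tags_and_others(["<<positive>> A", "B"]) returns
    # tags_list = [['positive'], [None]] and others_list = [['A'], ['B']]:
    # <<...>> marks phrases whose token sequence must be generated, while
    # the remaining symbols are kept as non-terminals.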
    final_rules = {}
    count = 0
    for lhs, rhs_list in productions.items():
        tags_list, non_terminals_list = extract_tags_and_others(rhs_list)
        filtered_tags = []
        filtered_non_terminals = []
        for j in range(len(tags_list)):
            tag_group = tags_list[j]
            non_terminal_group = non_terminals_list[j]
            if any(tag is not None for tag in tag_group):
                # Filter out None tags and add them directly to final_rules
                i = 0
                while i < len(tag_group):
                    if tag_group[i] is None:
                        # Add rule directly to final_rules
                        if lhs in final_rules:
                            final_rules[lhs].append(rhs_list[i])
                        else:
                            final_rules[lhs] = [rhs_list[i]]
                        # Remove processed tag and non-terminal
                        tag_group.pop(i)
                        non_terminal_group.pop(i)
                    else:
                        # Keep tag and non-terminal for further processing
                        filtered_tags.append(tag_group[i])
                        if i < len(non_terminal_group):
                            filtered_non_terminals.append(non_terminal_group[i])
                        i += 1
            else:
                # All tags are None, add rules directly
                final_rules.update({lhs: rhs_list})

        # print(f"Filtered tags: {filtered_tags}")  # DEBUG
        # print(f"Filtered non-terminals: {filtered_non_terminals}")  # DEBUG

        # Process remaining tags through the general pipeline
        if filtered_tags:
            final_rules, count = pipeline(
                filtered_tags, tokenizer, lhs,
                count=count,
                non_terminals=filtered_non_terminals,
                FINAL_RULES=final_rules
            )
    return final_rules, count
def get_parsing_table_and_map_tt(tokenizer, productions=None, regex_dict=None):
    def write_grammar_to_file(grammar_rules):
        """Write grammar rules to a file."""
        output_file = os.path.join('temp', 'grammar_rules.txt')
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        with open(output_file, 'w') as f:
            for non_terminal, rules in grammar_rules.items():
                for rule in rules:
                    f.write(f"{non_terminal} -> {rule}\n")
                f.write("\n")
        logging.info(f"\nGrammar Rules written to {output_file}")

    # Get final grammar rules
    final_rules, _ = process_grammar_rules(productions, tokenizer)
    # print(final_rules)  # DEBUG
    write_grammar_to_file(final_rules)
    logging.info(final_rules)

    # Generate parsing table
    pars_tab = parsing_table(final_rules)

    # Generate token maps
    if regex_dict:
        map_terminal_tokens = generate_token_maps(tokenizer, pars_tab, regex_dict)
    else:
        map_terminal_tokens = generate_token_maps(tokenizer, pars_tab)

    logging.info("\nMap Terminal Tokens:\n")
    for key, values in map_terminal_tokens.items():
        logging.info(f"{key} -> {values}")
    return pars_tab, map_terminal_tokens
def generate_grammar_parameters(tokenizer, pars_tab, map_terminal_tokens):
    # Create the pushdown automaton and initialize the logits processor and streamer
    pda = PushdownAutomaton(grammar=pars_tab, startSymbol='S*', map=map_terminal_tokens)
    return MaskLogitsProcessor(tokenizer, pda), BaseStreamer(tokenizer, pda)
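
# Typical wiring (sketch, mirroring run_grammarllm below; `prods` stands for
# your productions dict): build the parsing table and terminal-token map once
# per grammar, then reuse the processor/streamer pair across generations:
#   pars_tab, tt_map = get_parsing_table_and_map_tt(tokenizer, productions=prods)
#   processor, streamer = generate_grammar_parameters(tokenizer, pars_tab, tt_map)
#   answer = generate_text(model, tokenizer, prompt, processor, streamer)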
def setup_logging():
    """Setup logging configuration."""
    log_dir = 'temp'
    os.makedirs(log_dir, exist_ok=True)  # Ensure the log directory exists
    logging.basicConfig(
        filename=os.path.join(log_dir, 'GRAM-GEN.log'),
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        filemode='w+'  # Overwrites the file every time
    )
def generate_text(model, tokenizer, text, logit_processor, streamer, max_new_tokens=400, do_sample=False, temperature=None, top_p=None, **kwargs):
    """
    Generate grammar-constrained text, with safe handling of the generation parameters.

    Args:
        model: The pre-trained model.
        tokenizer: The model's tokenizer.
        text: Initial input text.
        logit_processor: Grammar-based logits processor.
        streamer: Streamer for live output.
        max_new_tokens: Maximum number of new tokens to generate.
        do_sample: If True, enables stochastic generation.
        temperature: Controls randomness (used only if do_sample=True).
        top_p: Top-p (nucleus sampling), used only if do_sample=True.
        **kwargs: Optional extra parameters for model.generate().
    """
    try:
        tokenized_input = tokenizer(text, return_tensors="pt")

        # Safe defaults
        kwargs.setdefault("num_beams", 1)  # beam search disabled
        kwargs.setdefault("pad_token_id", tokenizer.eos_token_id)

        # Guard against beam search
        if kwargs["num_beams"] != 1:
            logging.warning("num_beams > 1 is not compatible with grammar-constrained generation. Automatically resetting num_beams=1.")
            kwargs["num_beams"] = 1

        # Sampling parameters
        if do_sample:
            if temperature is not None:
                kwargs["temperature"] = temperature
            if top_p is not None:
                kwargs["top_p"] = top_p
        else:
            # Remove sampling parameters if present
            kwargs.pop("temperature", None)
            kwargs.pop("top_p", None)

        # Device compatibility: check before moving, so the warnings can actually fire
        device = model.device
        input_ids = tokenized_input["input_ids"]
        if input_ids.device != model.device:
            logging.warning(f"input_ids are on device {input_ids.device} while the model is on {model.device}; moving input_ids to the model's device.")
        input_ids = input_ids.to(device)
        attention_mask = tokenized_input["attention_mask"]
        if attention_mask.device != model.device:
            logging.warning(f"attention_mask is on device {attention_mask.device} while the model is on {model.device}; moving attention_mask to the model's device.")
        attention_mask = attention_mask.to(device)

        start = input_ids.shape[1]
        output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            logits_processor=[logit_processor],
            **kwargs
        )
        answer = tokenizer.decode(output[0][start:], skip_special_tokens=True)
        return answer
    except Exception as e:
        raise RuntimeError(f"Text generation error: {e}") from e
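
# Example (sketch): nucleus sampling can be enabled while keeping the grammar
# constraint, since logits processors are applied before sampling:
#   generate_text(model, tokenizer, prompt, processor, streamer,
#                 do_sample=True, temperature=0.8, top_p=0.95)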
def run_grammarllm(prompt, productions_json, model_choice, regex_json):
    setup_logging()

    # Parse productions
    try:
        productions = json.loads(productions_json)
    except json.JSONDecodeError:
        return "Error: invalid productions JSON.", None

    # Default regex set (overridden below by the user-supplied JSON)
    regex_raw = {
        "regex_alfanum": "[a-zA-Z0-9]+",
        "regex_letters": "[a-zA-Z]+",
        "regex_number": "\\d+",
        "regex_decimal": "\\d+([.,]\\d+)?",
        "regex_var": "[a-zA-Z_][a-zA-Z0-9_]*",
        "regex_)": "\\)",
        "regex_(": "\\("
    }
    try:
        regex_raw = json.loads(regex_json)
        regex_dict = {key: re.compile(pattern) for key, pattern in regex_raw.items()}
    except (json.JSONDecodeError, re.error) as e:
        return f"Error in the custom regexes: {str(e)}", None
    try:
        # Select the model based on the user's choice
        if model_choice == "GPT-2":
            model_name = "gpt2"
        elif model_choice == "Llama 3.2 1B":
            model_name = "meta-llama/Llama-3.2-1B-Instruct"
        # elif model_choice == "Llama 3.1 8B":
        #     model_name = "meta-llama/Llama-3.1-8B-Instruct"
        else:
            return f"Unsupported model: {model_choice}", None

        # Load the tokenizer and the model
        print(f"Loading model: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Configure device and dtype for performance
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if model_choice.startswith("Llama"):
            # For Llama models, use torch_dtype=torch.float16 to save memory
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )
        else:
            # For GPT-2
            model = AutoModelForCausalLM.from_pretrained(model_name)
            model = model.to(device)

        # Add a pad_token if one does not exist
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        pars_table, map_terminal_tokens = get_parsing_table_and_map_tt(
            tokenizer,
            productions=productions,
            regex_dict=regex_dict,
        )
        LogitProcessor, Streamer = generate_grammar_parameters(tokenizer, pars_table, map_terminal_tokens)
        output = generate_text(model, tokenizer, prompt, LogitProcessor, Streamer)
        # Create the ZIP file with the generated artifacts
        temp_dir = "./temp"
        zip_path = temp_dir + ".zip"

        # Make sure temp_dir exists
        if os.path.exists(temp_dir):
            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
                for root, dirs, files in os.walk(temp_dir):
                    for file in files:
                        file_path = os.path.join(root, file)
                        arcname = os.path.relpath(file_path, temp_dir)
                        zipf.write(file_path, arcname)
        else:
            zip_path = None

        # Free the model's memory
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return output, zip_path
    except Exception as e:
        return f"Error during inference: {str(e)}", None
default_grammars = {
    "HC Grammar": json.dumps({
        "S*": ["<<positive>> A", "<<negative>> B", "<<neutral>> C"],
        "A": ["<<happy>> D", "<<peaceful>> E", "<<joyful>> F"],
        "B": ["<<sad>>", "<<angry>>", "<<frustrated>>"],
        "C": ["<<calm>>", "<<indifferent>>", "<<unemotional>>"],
        "D": ["<<enthusiastic>>"],
        "E": ["<<content>>"],
        "F": ["<<excited>>"]
    }, indent=4),
    "VR Grammar": json.dumps({
        "S*": ["<<positive>> S*", "<<negative>> S*", "<<neutral>> S*"],
    }, indent=4),
    "General Grammar": json.dumps({
        'S*': ["( LETTERS )"],
        'LETTERS': ['letters number LETTERS', "ε"]
    }, indent=4),
}
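
# Reading the grammar JSON: each key is a non-terminal ('S*' is the start
# symbol), each list entry is one production, <<word>> marks a phrase whose
# token sequence the generator must emit, and "ε" denotes the empty
# production. For instance, "General Grammar" accepts strings such as
# "( abc 12 de 3 )", assuming the 'letters' and 'number' terminals map to
# the regex_letters and regex_number patterns defined below.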
def update_productions(grammar_choice):
    # Update the productions textbox when the preset changes
    return default_grammars[grammar_choice]


def load_file(file_obj):
    if file_obj is None:
        return "Error: no file uploaded."
    try:
        # In newer Gradio versions, file_obj is a path string, not a file object
        if isinstance(file_obj, str):
            # file_obj is the file path
            with open(file_obj, 'r', encoding='utf-8') as f:
                content = f.read()
        else:
            # Fallback for older Gradio versions or different file object types
            if hasattr(file_obj, 'name'):
                # file_obj has a 'name' attribute containing the path
                with open(file_obj.name, 'r', encoding='utf-8') as f:
                    content = f.read()
            else:
                # Try to read directly (old behavior)
                content = file_obj.read().decode("utf-8")
        json.loads(content)  # check that it is valid JSON
        return content
    except Exception as e:
        return f"File upload error: {str(e)}"
# Gradio interface
with gr.Blocks(title="GrammarLLM - enable structured generation via formal language") as demo:
    gr.Markdown("# GrammarLLM - enable structured generation via LLprefix")
    gr.Markdown("")

    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Insert your prompt",
                placeholder="Type your prompt here...",
                lines=3
            )
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                choices=["GPT-2", "Llama 3.2 1B"],  # , "Llama 3.1 8B"
                label="Choose the model",
                value="GPT-2",
                interactive=True
            )
    with gr.Row():
        with gr.Column():
            grammar_choice = gr.Dropdown(
                list(default_grammars.keys()),
                label="Choose Productions (JSON)",
                value="HC Grammar",
                interactive=True,
                elem_id="grammar_choice"
            )
        with gr.Column():
            productions_upload = gr.File(
                label="Upload file Productions (JSON)",
                file_types=['.json']
            )

    productions_text = gr.Textbox(
        label="Productions (JSON)",
        lines=15,
        value=default_grammars["HC Grammar"],
        info="Type your grammar here in JSON format"
    )

    regex_text = gr.Textbox(
        label="Regex to define Terminals (JSON)",
        lines=10,
        value=json.dumps({
            "regex_alfanum": "[a-zA-Z0-9]+",
            "regex_letters": "[a-zA-Z]+",
            "regex_number": "\\d+",
            "regex_decimal": "\\d+([.,]\\d+)?",
            "regex_var": "[a-zA-Z_][a-zA-Z0-9_]*",
            "regex_)": "\\)",
            "regex_(": "\\("
        }, indent=4),
        info="Modify these common regexes"
    )
    with gr.Row():
        submit_btn = gr.Button("🚀 Generate Output", variant="primary", size="lg")
        clear_btn = gr.Button("🗑️ Clean", variant="secondary")

    with gr.Row():
        with gr.Column():
            output_text = gr.Textbox(
                label="Output generated",
                lines=10,
                show_copy_button=True
            )
        with gr.Column():
            zip_file = gr.File(label="📦 Download ZIP (if available)")

    with gr.Accordion("ℹ️ About GrammarLLM and LLprefix", open=False):
        gr.Markdown("""
        ### What is GrammarLLM?
        GrammarLLM enables structured text generation constrained by a formal grammar, using LLMs (Large Language Models) such as GPT-2 or LLaMA.

        ### What you can do:
        - **Hierarchical classification**: Define class hierarchies, as shown in the "HC Grammar" example.
        - **Vocabulary restriction**: Specify a limited set of valid words to be used. Including examples in the prompt is highly recommended to improve output quality.
        - **Constrained generation**: Use LLprefix to define any regular or context-free grammar in JSON format.

        For more details about LLprefix and the underlying algorithms, refer to the official paper.
        """)
    # Callback: when the dropdown changes, update productions_text
    grammar_choice.change(
        fn=update_productions,
        inputs=grammar_choice,
        outputs=productions_text,
    )

    # Callback: when a productions file is uploaded, update productions_text (overrides the dropdown)
    productions_upload.upload(
        fn=load_file,
        inputs=productions_upload,
        outputs=productions_text,
    )

    # On submit, call run_grammarllm
    submit_btn.click(
        fn=run_grammarllm,
        inputs=[prompt_input, productions_text, model_choice, regex_text],
        outputs=[output_text, zip_file],
        show_progress=True
    )
    # Reset all fields
    def clear_fields():
        return "", default_grammars["HC Grammar"], "", None, json.dumps({
            "regex_alfanum": "[a-zA-Z0-9]+",
            "regex_letters": "[a-zA-Z]+",
            "regex_number": "\\d+",
            "regex_decimal": "\\d+([.,]\\d+)?",
            "regex_var": "[a-zA-Z_][a-zA-Z0-9_]*",
            "regex_)": "\\)",
            "regex_(": "\\("
        }, indent=4)

    clear_btn.click(
        fn=clear_fields,
        outputs=[prompt_input, productions_text, output_text, zip_file, regex_text]
    )
if __name__ == "__main__":
    demo.launch()