# grammarllm / app.py
# Gabriele Tuccio
# update
# d3289a1
from grammarllm.scripts.grammar_generation import generate_non_terminals, generate_grammar
from grammarllm.scripts.map_terminal_tokens import generate_token_maps
from grammarllm.scripts.table_parsing import parsing_table
from grammarllm.modules.BaseStreamer import BaseStreamer
from grammarllm.modules.PushdownAutomaton import PushdownAutomaton
from grammarllm.modules.SimpleLogitProcessor import MaskLogitsProcessor
import logging
import re
import os
from collections import defaultdict
from tqdm import tqdm
from grammarllm.utils.common_regex import regex_dict
from grammarllm.utils.examples import *
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import json
import zipfile
import spaces
import torch
from huggingface_hub import login
# Authenticate against the Hugging Face Hub (needed for the gated Llama
# checkpoints). Only log in when the secret is actually configured: calling
# login() with token=None falls back to an interactive prompt, which cannot
# work in a headless deployment. print() is used instead of logging.warning()
# on purpose: a module-level logging call would install a default handler and
# turn the later logging.basicConfig(...) in setup_logging() into a no-op.
_hf_token = os.getenv("llama_acces_token")
if _hf_token:
    login(token=_hf_token)
else:
    print("llama_acces_token is not set: skipping Hugging Face login (gated models will be unavailable).")
def pipeline(words, tokenizer, lhs, count=0, non_terminals=None, FINAL_RULES=None):
    """
    Turn a set of words/phrases into grammar rules and merge them into FINAL_RULES.

    This is a per-production pre-processing step. The input goes through:
    tokenization and transition building (build_SState), grouping of the
    transitions (group_by_prefix), non-terminal generation
    (generate_non_terminals) and rule generation (generate_grammar). The
    resulting rules are merged into FINAL_RULES and every intermediate
    structure is written to the log.

    Args:
        words (list): Words or phrases to process.
        tokenizer: Object exposing ``tokenize(str) -> list`` used to split each word.
        lhs (str): Left-hand side symbol, forwarded to generate_grammar as ``NT``.
        count (int, optional): Counter forwarded to generate_non_terminals.
            Defaults to 0. Presumably used to keep generated non-terminal
            names (primes/apostrophes) unique across calls -- verify against
            generate_non_terminals.
        non_terminals (list, optional): Forwarded to generate_grammar as
            ``non_terminals_list``.
        FINAL_RULES (dict, optional): Existing rules to extend. Mutated in
            place when provided; a fresh dict is created when omitted.

    Returns:
        tuple: ``(FINAL_RULES, count)`` where ``count`` is the input counter
        incremented by one.
    """
    # The default used to be dereferenced directly (``key in None``), which
    # raised TypeError whenever the caller relied on the default.
    if FINAL_RULES is None:
        FINAL_RULES = {}

    def build_SState(classes, tokenizer):
        """Return a list of ``(position, token, id)`` transitions for all classes.

        ``position`` is the index of the token inside its own word and ``id``
        is a running counter starting at 1 over every emitted transition.
        """
        SState = []
        tokenized_classes = [tokenizer.tokenize(c) for c in classes]
        glob_count = 1
        # Context manager guarantees the progress bar is closed on error too.
        with tqdm(total=len(classes), desc="Build state") as pbar:
            for tok_class in tokenized_classes:
                # The original guarded the append with ``token not in SState``;
                # SState holds tuples while token is a string, so that test
                # was always true and has been dropped without changing output.
                for state, token in enumerate(tok_class):
                    SState.append((state, token, glob_count))
                    glob_count += 1
                pbar.update(1)
        logging.info(SState)
        return SState

    def group_by_prefix(transitions):
        """Group transitions as ``{state: {symbol: [(symbol, end), ...]}}``."""
        grouped = defaultdict(lambda: defaultdict(list))
        for state, symbol, end in transitions:
            grouped[state][symbol].append((symbol, end))
        return grouped

    transitions = build_SState(words, tokenizer)
    grouped_data = group_by_prefix(transitions)

    # Generate non-terminals
    G, S = generate_non_terminals(grouped_data, count=count)
    count += 1  # bumped once per call (added to handle the primes/apostrophes issue)

    grammar_rules = generate_grammar(G, S, NT=lhs, eos_symbol='|eot|', non_terminals_list=non_terminals)

    # Merge the freshly generated rules into the accumulated rule set.
    for key, values in grammar_rules.items():
        if key in FINAL_RULES:
            FINAL_RULES[key].extend(values)
        else:
            FINAL_RULES[key] = values

    logging.info("\nGrouped Data:")
    for state, prefixes in grouped_data.items():
        logging.info("State %s:", state)
        for prefix, class_labels_list in prefixes.items():
            logging.info(" %s -> %s", prefix, class_labels_list)

    logging.info("\n Generated Non-Terminals:\n")
    for nt, prefix in G.items():
        logging.info("%s -> %s", nt, prefix)

    logging.info("\n Ends Non-Terminals:\n")
    for nt, prefix in S.items():
        logging.info("%s -> %s", nt, prefix)

    logging.info("\nGrammar Rules:\n")
    for rules in grammar_rules.values():
        for rule in rules:
            logging.info("%s", rule)

    return FINAL_RULES, count
def process_grammar_rules(productions, tokenizer):
    """
    Expand user productions into the final grammar rules.

    Each right-hand side alternative is inspected for ``<<tag>>`` markers:

    * alternatives without any tag are copied verbatim into the final rules;
    * for alternatives with tags, the tag contents and the remaining text
      (joined with single spaces, or ``None`` when nothing remains) are
      collected per left-hand side and handed to ``pipeline``, which
      generates the corresponding rules and merges them into the result.

    Args:
        productions (dict): Maps each left-hand side symbol to a list of
            right-hand side strings.
        tokenizer: Tokenizer forwarded to ``pipeline``.

    Returns:
        tuple: ``(final_rules, count)`` -- the final rules dict and the
        counter returned by the last ``pipeline`` call (0 if it never ran).
    """
    tag_pattern = re.compile(r'<<(.+?)>>')

    def smart_split(item):
        """Split *item* on whitespace while keeping every ``<<...>>`` tag whole."""
        parts = []
        last_index = 0
        for match in tag_pattern.finditer(item):
            # Text before the tag, split on whitespace
            parts.extend(item[last_index:match.start()].strip().split())
            # The whole tag as a single unit
            parts.append(match.group(0))
            last_index = match.end()
        # Any text after the last tag
        parts.extend(item[last_index:].strip().split())
        return parts

    def extract_tags_and_others(rhs_list):
        """Return parallel lists of tags and leftover text, one entry per alternative.

        Untagged alternatives yield ``[None]`` / ``[item]``; tagged ones yield
        the list of tag contents and a one-element list holding the leftover
        text (or ``None``).
        """
        tags_list = []
        others_list = []
        for item in rhs_list:
            if not tag_pattern.search(item):
                tags_list.append([None])
                others_list.append([item])
                continue
            tags = []
            current_chunk = []
            for word in smart_split(item):
                match = tag_pattern.fullmatch(word)
                if match:
                    tags.append(match.group(1))  # keep only the tag content
                else:
                    current_chunk.append(word)
            tags_list.append(tags)
            others_list.append([' '.join(current_chunk)] if current_chunk else [None])
        return tags_list, others_list

    final_rules = {}
    count = 0
    for lhs, rhs_list in productions.items():
        tags_list, non_terminals_list = extract_tags_and_others(rhs_list)
        filtered_tags = []
        filtered_non_terminals = []

        for rhs, tag_group, non_terminal_group in zip(rhs_list, tags_list, non_terminals_list):
            if any(tag is not None for tag in tag_group):
                # A tagged group never contains None (see extract_tags_and_others),
                # so the original per-element None handling was unreachable.
                filtered_tags.extend(tag_group)
                # Same pairing as before: one non-terminal per tag index, as far
                # as the (single-element) non-terminal list reaches.
                filtered_non_terminals.extend(non_terminal_group[:len(tag_group)])
            else:
                # Copy only this untagged alternative. The previous code stored
                # the whole rhs_list object here, which leaked raw "<<tag>>"
                # alternatives into the final rules for mixed productions and
                # aliased the caller's list (later mutated by pipeline's extend).
                final_rules.setdefault(lhs, []).append(rhs)

        # Process the tagged alternatives through the general pipeline
        if filtered_tags:
            final_rules, count = pipeline(
                filtered_tags, tokenizer, lhs,
                count=count,
                non_terminals=filtered_non_terminals,
                FINAL_RULES=final_rules
            )

    return final_rules, count
def get_parsing_table_and_map_tt(tokenizer, productions=None, regex_dict=None):
    """
    Build the parsing table and the terminal-to-token map for a grammar.

    The productions are expanded with ``process_grammar_rules``, dumped to
    ``temp/grammar_rules.txt`` and logged, then passed to ``parsing_table``
    and ``generate_token_maps``.

    Args:
        tokenizer: Tokenizer forwarded to the rule expansion and token mapping.
        productions (dict): Grammar productions (left-hand side -> list of
            right-hand sides). Required despite the ``None`` default.
        regex_dict (dict, optional): Regex definitions forwarded to
            ``generate_token_maps``; when empty or ``None`` that function is
            called without it (i.e. with its own default).

    Returns:
        tuple: ``(pars_tab, map_terminal_tokens)``.

    Raises:
        ValueError: If ``productions`` is ``None``.
    """
    if productions is None:
        # Previously surfaced as an opaque AttributeError on ``None.items()``.
        raise ValueError("productions must be provided")

    def write_grammar_to_file(grammar_rules):
        """Write grammar rules to a file"""
        output_file = os.path.join('temp', 'grammar_rules.txt')
        os.makedirs(os.path.dirname(output_file), exist_ok=True)
        # Explicit UTF-8: grammars may contain non-ASCII symbols (e.g. epsilon),
        # which would fail under a non-UTF-8 default locale encoding.
        with open(output_file, 'w', encoding='utf-8') as f:
            for non_terminal, rules in grammar_rules.items():
                for rule in rules:
                    f.write(f"{non_terminal} -> {rule}\n")
                f.write("\n")
        logging.info(f"\nGrammar Rules to {output_file}")

    # Get final grammar rules
    final_rules, _ = process_grammar_rules(productions, tokenizer)
    write_grammar_to_file(final_rules)
    logging.info(final_rules)

    # Generate parsing table
    pars_tab = parsing_table(final_rules)

    # Generate token maps
    if regex_dict:
        map_terminal_tokens = generate_token_maps(tokenizer, pars_tab, regex_dict)
    else:
        map_terminal_tokens = generate_token_maps(tokenizer, pars_tab)

    logging.info("\nMap Terminal Tokens:\n")
    for key, values in map_terminal_tokens.items():
        logging.info(f"{key} -> {values}")

    return pars_tab, map_terminal_tokens
def generate_grammar_parameters(tokenizer, pars_tab, map_terminal_tokens):
    """Create the grammar-driven logits processor and streamer.

    A single PushdownAutomaton (start symbol ``'S*'``) is built from the
    parsing table and the terminal-token map, and the same instance is shared
    by both returned objects.

    Returns:
        tuple: ``(MaskLogitsProcessor, BaseStreamer)``.
    """
    automaton = PushdownAutomaton(
        grammar=pars_tab,
        startSymbol='S*',
        map=map_terminal_tokens,
    )
    masking_processor = MaskLogitsProcessor(tokenizer, automaton)
    token_streamer = BaseStreamer(tokenizer, automaton)
    return masking_processor, token_streamer
def setup_logging():
    """Configure root logging to write INFO-level records to temp/GRAM-GEN.log.

    The ``temp`` directory is created when missing and the log file is opened
    in ``'w+'`` mode, so it is truncated when the handler is installed.
    """
    target_dir = 'temp'
    # Make sure the directory exists before basicConfig opens the file.
    os.makedirs(target_dir, exist_ok=True)
    settings = {
        'filename': os.path.join(target_dir, 'GRAM-GEN.log'),
        'level': logging.INFO,
        'format': '%(asctime)s - %(levelname)s - %(message)s',
        'filemode': 'w+',  # overwrite rather than append
    }
    logging.basicConfig(**settings)
def generate_text(model, tokenizer, text, logit_processor, streamer, max_new_tokens=400, do_sample=False, temperature=None, top_p=None, **kwargs):
    """
    Generate grammar-constrained text with safe generation settings.

    Args:
        model: The pre-trained model; must expose ``device`` and ``generate``.
        tokenizer: The model's tokenizer.
        text: Initial input text (the prompt).
        logit_processor: Grammar-based logits processor.
        streamer: Streamer for live output.
        max_new_tokens: Maximum number of new tokens to generate.
        do_sample: If True, enables stochastic generation.
        temperature: Sampling temperature (used only if do_sample=True).
        top_p: Top-p / nucleus sampling (used only if do_sample=True).
        **kwargs: Extra optional parameters forwarded to ``model.generate()``.
            ``num_beams`` is forced to 1 and ``pad_token_id`` defaults to the
            tokenizer's EOS token id.

    Returns:
        str: The decoded continuation only (prompt tokens are stripped),
        with special tokens skipped.

    Raises:
        RuntimeError: Wraps any exception raised during tokenization,
            generation or decoding; the original is attached as ``__cause__``.
    """
    try:
        tokenized_input = tokenizer(text, return_tensors="pt")

        # Safe defaults
        kwargs.setdefault("num_beams", 1)  # beam search disabled
        kwargs.setdefault("pad_token_id", tokenizer.eos_token_id)

        # Beam search is not compatible with grammar-constrained generation.
        if kwargs["num_beams"] != 1:
            logging.warning("⚠️ num_beams > 1 non è compatibile con la generazione vincolata da grammatica. Impostato automaticamente a num_beams=1.")
            kwargs["num_beams"] = 1

        # Sampling parameters
        if do_sample:
            if temperature is not None:
                kwargs["temperature"] = temperature
            if top_p is not None:
                kwargs["top_p"] = top_p
        else:
            # Drop sampling parameters if the caller passed them anyway
            kwargs.pop("temperature", None)
            kwargs.pop("top_p", None)

        # Device compatibility: move the inputs to the model's device. The
        # checks below run after the move, so they are purely defensive.
        device = model.device
        input_ids = tokenized_input["input_ids"].to(device)
        if input_ids.device != model.device:
            # The f-prefix was missing here, so the placeholders were logged literally.
            logging.warning(f"Errore: gli 'input_ids' sono sulla device {input_ids.device}, mentre il modello è sulla device {model.device}. Spostando 'input_ids' sulla stessa device del modello.")
        attention_mask = tokenized_input["attention_mask"].to(device)
        if attention_mask.device != model.device:
            logging.warning(f"Errore: l'attention_mask è sulla device {attention_mask.device}, mentre il modello è sulla device {model.device}. Spostando 'attention_mask' sulla stessa device del modello.")

        # Number of prompt tokens, used to strip the prompt from the output.
        start = input_ids.shape[1]
        output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            logits_processor=[logit_processor],
            **kwargs
        )
        answer = tokenizer.decode(output[0][start:], skip_special_tokens=True)
        return answer
    except Exception as e:
        raise RuntimeError(f"Errore nella generazione del testo: {e}") from e
@spaces.GPU
def run_grammarllm(prompt, productions_json, model_choice,regex_json):
    """
    Run one grammar-constrained generation for the Gradio UI.

    Args:
        prompt (str): Prompt passed to the model.
        productions_json (str): Grammar productions as a JSON object.
        model_choice (str): ``"GPT-2"`` or ``"Llama 3.2 1B"``.
        regex_json (str): Terminal regex definitions as a JSON object
            (name -> pattern). When blank, the built-in defaults are used.

    Returns:
        tuple: ``(text, zip_path)``. ``text`` is the generated output or an
        error message; ``zip_path`` is the archive of the ``./temp``
        directory (log and grammar dump) or ``None`` when unavailable or on
        error.
    """
    setup_logging()

    # Parse productions; they must be a JSON object.
    try:
        productions = json.loads(productions_json)
    except (json.JSONDecodeError, TypeError):
        return "Errore: JSON productions non valido.", None
    if not isinstance(productions, dict):
        return "Errore: JSON productions non valido.", None

    # Built-in regex definitions, used when the user leaves the field empty.
    # (This dict used to be assigned and then unconditionally overwritten.)
    default_regex = {
        "regex_alfanum": "[a-zA-Z0-9]+",
        "regex_letters": "[a-zA-Z]+",
        "regex_number": "\\d+",
        "regex_decimal": "\\d+([.,]\\d+)?",
        "regex_var": "[a-zA-Z_][a-zA-Z0-9_]*",
        "regex_)": "\\)",
        "regex_(": "\\("
    }
    try:
        if regex_json is None or not regex_json.strip():
            regex_raw = default_regex
        else:
            regex_raw = json.loads(regex_json)
        if not isinstance(regex_raw, dict):
            raise TypeError("le regex devono essere un oggetto JSON (nome -> pattern)")
        # re.compile raises TypeError for non-string patterns; handled below.
        compiled_regex = {key: re.compile(pattern) for key, pattern in regex_raw.items()}
    except (json.JSONDecodeError, re.error, TypeError) as e:
        return f"Errore nelle regex personalizzate: {str(e)}", None

    model = None
    try:
        # Model selection based on the user's choice
        if model_choice == "GPT-2":
            model_name = "gpt2"
        elif model_choice == "Llama 3.2 1B":
            model_name = "meta-llama/Llama-3.2-1B-Instruct"
        else:
            return f"Modello non supportato: {model_choice}", None

        # Load tokenizer and model
        print(f"Caricamento del modello: {model_name}")
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Device / dtype configuration
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if model_choice.startswith("Llama"):
            # Llama models: float16 to save memory, placement left to device_map
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True
            )
        else:
            # GPT-2
            model = AutoModelForCausalLM.from_pretrained(model_name)
            model = model.to(device)

        # Add a pad token if the tokenizer has none
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        pars_table, map_terminal_tokens = get_parsing_table_and_map_tt(
            tokenizer,
            productions=productions,
            regex_dict=compiled_regex,
        )
        LogitProcessor, Streamer = generate_grammar_parameters(tokenizer, pars_table, map_terminal_tokens)
        output = generate_text(model, tokenizer, prompt, LogitProcessor, Streamer)

        # Bundle the temp directory (log + grammar dump) into a ZIP archive
        temp_dir = "./temp"
        zip_path = temp_dir + ".zip"
        if os.path.exists(temp_dir):
            with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
                for root, dirs, files in os.walk(temp_dir):
                    for file in files:
                        file_path = os.path.join(root, file)
                        arcname = os.path.relpath(file_path, temp_dir)
                        zipf.write(file_path, arcname)
        else:
            zip_path = None

        return output, zip_path
    except Exception as e:
        logging.exception("Inference failed")
        return f"Errore durante l'inferenza: {str(e)}", None
    finally:
        # Release the model on every exit path, not only on success.
        model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Preset grammars offered in the UI, keyed by display name. Each value is the
# JSON text shown in the productions textbox. "<<...>>" marks tagged words
# that are expanded into token-level rules; "S*" is the start symbol.
default_grammars = {
    # Hierarchical classification: a sentiment followed by a finer emotion.
    "HC Grammar": json.dumps({
        "S*": ["<<positive>> A", "<<negative>> B", "<<neutral>> C"],
        "A": ["<<happy>> D", "<<peaceful>> E", "<<joyful>> F"],
        "B": ["<<sad>>", "<<angry>>", "<<frustrated>>"],
        "C": ["<<calm>>", "<<indifferent>>", "<<unemotional>>"],
        "D": ["<<enthusiastic>>"],
        "E": ["<<content>>"],
        "F": ["<<excited>>"]
    }, indent=4),
    # Vocabulary restriction: any sequence of the three allowed words.
    "VR Grammar": json.dumps({
        "S*": ["<<positive>> S*", "<<negative>> S*", "<<neutral>> S*"],
    }, indent=4),
    # General grammar over regex terminals. The second LETTERS alternative is
    # the empty-string symbol epsilon; it was stored as mojibake, which is
    # repaired here. ensure_ascii=False keeps it readable in the textbox
    # instead of a \uXXXX escape.
    "General Grammar": json.dumps({
        'S*': ["( LETTERS )"],
        'LETTERS': ['letters number LETTERS', "ε"]
    }, indent=4, ensure_ascii=False),
}
def update_productions(grammar_choice):
    """Return the preset productions JSON for the selected grammar name.

    Used as the dropdown callback to refresh the productions textbox.
    """
    preset_json = default_grammars[grammar_choice]
    return preset_json
def load_file(file_obj):
    """Read an uploaded productions file and return its text if it is valid JSON.

    ``file_obj`` may be a path string (newer Gradio versions), an object with
    a ``name`` attribute holding the path, or a binary file-like object.
    Failures are reported as an error string instead of being raised.
    """
    if file_obj is None:
        return "Errore: nessun file caricato."

    try:
        given_as_path = isinstance(file_obj, str)
        if given_as_path or hasattr(file_obj, 'name'):
            # Either the path itself or an upload wrapper exposing it via .name
            location = file_obj if given_as_path else file_obj.name
            with open(location, 'r', encoding='utf-8') as handle:
                text = handle.read()
        else:
            # Legacy behaviour: read the bytes straight from the object
            text = file_obj.read().decode("utf-8")

        json.loads(text)  # validation only; the raw text is what gets returned
        return text
    except Exception as exc:
        return f"Errore nel caricamento file: {str(exc)}"
# Gradio interface.
# Default regex definitions shown in the regex textbox; defined once so the
# initial value and the "Clean" reset cannot drift apart.
DEFAULT_REGEX_JSON = json.dumps({
    "regex_alfanum": "[a-zA-Z0-9]+",
    "regex_letters": "[a-zA-Z]+",
    "regex_number": "\\d+",
    "regex_decimal": "\\d+([.,]\\d+)?",
    "regex_var": "[a-zA-Z_][a-zA-Z0-9_]*",
    "regex_)": "\\)",
    "regex_(": "\\("
}, indent=4)

# NOTE(review): the nesting of the layout below was reconstructed from the
# statement order; confirm the Row/Column placement against the deployed UI.
with gr.Blocks(title="GrammarLLM - enable structured generation via formal language") as demo:
    gr.Markdown("# GrammarLLM - enable structured generation via LLprefix")
    gr.Markdown("")

    # Prompt and model selection
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Insert your prompt",
                placeholder="Type here your prompt...",
                lines=3
            )
        with gr.Column(scale=1):
            model_choice = gr.Dropdown(
                choices=["GPT-2", "Llama 3.2 1B"],
                label="Choose the model",
                value="GPT-2",
                interactive=True
            )

    # Grammar preset selection or file upload
    with gr.Row():
        with gr.Column():
            grammar_choice = gr.Dropdown(
                list(default_grammars.keys()),
                label="Choose Productions (JSON)",
                value="HC Grammar",
                interactive=True,
                elem_id="grammar_choice"
            )
        with gr.Column():
            productions_upload = gr.File(
                label="Upload file Productions (JSON)",
                file_types=['.json']
            )

    productions_text = gr.Textbox(
        label="Productions (JSON)",
        lines=15,
        value=default_grammars["HC Grammar"],
        info="Type your here your grammar in json fromat"
    )
    regex_text = gr.Textbox(
        label="Regex to define Terminals (JSON)",
        lines=10,
        value=DEFAULT_REGEX_JSON,
        info="Modify these common regex"
    )

    with gr.Row():
        submit_btn = gr.Button("🚀 Generate Output", variant="primary", size="lg")
        clear_btn = gr.Button("🗑️ Clean", variant="secondary")

    with gr.Row():
        with gr.Column():
            output_text = gr.Textbox(
                label="Output generated",
                lines=10,
                show_copy_button=True
            )
        with gr.Column():
            zip_file = gr.File(label="📦 Download ZIP (if available)")

    with gr.Accordion("ℹ️ About GrammarLLM and LLprefix", open=False):
        gr.Markdown("""
### 📚 What is GrammarLLM?
GrammarLLM enables structured text generation constrained by a formal grammar, using LLMs (Large Language Models) such as GPT-2 or LLaMA.
### 🔍 What you can do:
- **Hierarchical classification**: Define class hierarchies, as shown in the "HC Grammar" example.
- **Vocabulary restriction**: Specify a limited set of valid words to be used. Including examples in the prompt is highly recommended to improve output quality.
- **Constrained generation**: Use LLprefix to define any regular or context-free grammar in JSON format.
📄 For more details about LLprefix and the underlying algorithms, refer to the official paper.
""")

    # Callback: changing the preset dropdown refreshes productions_text
    grammar_choice.change(
        fn=update_productions,
        inputs=grammar_choice,
        outputs=productions_text,
    )

    # Callback: uploading a productions file overrides productions_text
    productions_upload.upload(
        fn=load_file,
        inputs=productions_upload,
        outputs=productions_text,
    )

    # Submitting the form runs run_grammarllm
    submit_btn.click(
        fn=run_grammarllm,
        inputs=[prompt_input, productions_text, model_choice, regex_text],
        outputs=[output_text, zip_file],
        show_progress=True
    )

    def clear_fields():
        """Reset prompt, productions, output, ZIP and regex fields to their defaults.

        The returned values are ordered to match the ``outputs`` list of
        ``clear_btn.click`` below.
        """
        # The preset key is "HC Grammar"; the former lookup of "HC" raised
        # KeyError every time the Clean button was pressed.
        return "", default_grammars["HC Grammar"], "", None, DEFAULT_REGEX_JSON

    clear_btn.click(
        fn=clear_fields,
        outputs=[prompt_input, productions_text, output_text, zip_file, regex_text]
    )

if __name__ == "__main__":
    demo.launch()