Gabriele Tuccio committed on
Commit
cbb9121
·
1 Parent(s): 4e0abc9
Files changed (2) hide show
  1. app.py +197 -71
  2. requirements.txt +2 -1
app.py CHANGED
@@ -371,8 +371,36 @@ def generate_text(model, tokenizer, text, logit_processor, streamer, max_new_tok
371
  raise RuntimeError(f"Errore nella generazione del testo: {e}")
372
 
373
 
 
 
 
 
 
 
 
 
 
 
 
374
 
375
- def run_grammarllm(prompt, productions_json, regex_json):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
  setup_logging()
377
 
378
  # Parsing productions
@@ -381,75 +409,113 @@ def run_grammarllm(prompt, productions_json, regex_json):
381
  except json.JSONDecodeError:
382
  return "Errore: JSON productions non valido.", None
383
 
384
- # Parsing regex_dict
 
 
 
 
 
 
 
 
 
 
385
  try:
386
- regex_raw = json.loads(regex_json)
387
  regex_dict = {key: re.compile(pattern) for key, pattern in regex_raw.items()}
388
- except json.JSONDecodeError:
389
- return "Errore: JSON regex non valido.", None
390
  except re.error as e:
391
  return f"Errore nella compilazione regex: {str(e)}", None
392
 
393
  try:
394
- tokenizer = AutoTokenizer.from_pretrained("gpt2")
395
- model = AutoModelForCausalLM.from_pretrained("gpt2")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
396
 
397
  pars_table, map_terminal_tokens = get_parsing_table_and_map_tt(
398
- tokenizer,
399
- productions=productions,
400
  regex_dict=regex_dict,
401
  )
402
 
403
  LogitProcessor, Streamer = generate_grammar_parameters(tokenizer, pars_table, map_terminal_tokens)
404
  output = generate_text(model, tokenizer, prompt, LogitProcessor, Streamer)
405
 
 
406
  temp_dir = "./temp"
407
  zip_path = temp_dir + ".zip"
408
 
409
- with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
410
- for root, dirs, files in os.walk(temp_dir):
411
- for file in files:
412
- file_path = os.path.join(root, file)
413
- arcname = os.path.relpath(file_path, temp_dir)
414
- zipf.write(file_path, arcname)
 
 
 
 
 
 
 
 
415
 
416
  return output, zip_path
417
 
418
  except Exception as e:
419
  return f"Errore durante l'inferenza: {str(e)}", None
420
 
 
421
  default_grammars = {
422
  "Default Grammar": json.dumps({
423
- "S*": ["<<positive>> A", "<<negative>> B", "<<neutral>> C"],
424
- "A": ["<<happy>> D", "<<peaceful>> E", "<<joyful>> F"],
425
- "B": ["<<sad>>", "<<angry>>", "<<frustrated>>"],
426
- "C": ["<<calm>>", "<<indifferent>>", "<<unemotional>>"],
427
- "D": ["<<enthusiastic>>"],
428
- "E": ["<<content>>"],
429
- "F": ["<<excited>>"]
430
  }, indent=4),
431
 
432
  "Other example": json.dumps({
433
- 'S*': ["<<(>> A B", "<<negligent>> V", '<<indifferent>>'],
434
- 'A': ["number", "letters", "ε"],
435
- 'B': ['<<)>> letters R'],
436
- 'R': ['C', 'D'],
437
- 'C': ['<<calm>>', '<<indifferent>>', '<<unemotional>>'],
438
- 'D': ['<<angry>>', '<<frustrated>>'],
439
- 'V': ["<<option>>"],
440
  }, indent=4),
441
  }
442
 
443
- default_regex_json = json.dumps({
444
- "regex_alfanum": "[a-zA-Z0-9]+",
445
- "regex_letters": "[a-zA-Z]+",
446
- "regex_number": "\\d+",
447
- "regex_decimal": "\\d+([.,]\\d+)?",
448
- "regex_var": "[a-zA-Z_][a-zA-Z0-9_]*",
449
- "regex_)": "\\)",
450
- "regex_(": "\\("
451
- }, indent=4)
452
-
453
 
454
  def update_productions(grammar_choice):
455
  # Aggiorna textbox productions al cambio preset
@@ -458,39 +524,98 @@ def update_productions(grammar_choice):
458
 
459
  def load_file(file_obj):
460
  if file_obj is None:
461
- return ""
462
  try:
463
- content = file_obj.read().decode("utf-8")
464
- # opzionale: validare JSON?
465
- json.loads(content)
 
 
 
 
 
 
 
 
 
 
 
 
 
466
  return content
467
  except Exception as e:
468
  return f"Errore nel caricamento file: {str(e)}"
469
 
470
 
471
- with gr.Blocks() as demo:
472
-
473
- prompt_input = gr.Textbox(label="Inserisci prompt testuale")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
 
475
  with gr.Row():
476
- grammar_choice = gr.Dropdown(
477
- list(default_grammars.keys()),
478
- label="Scegli Productions (JSON)",
479
- value="Default Grammar",
480
- interactive=True,
481
- elem_id="grammar_choice"
482
- )
483
- productions_upload = gr.File(label="Carica file Productions (JSON)", file_types=['.json'])
484
-
485
- productions_text = gr.Textbox(label="Productions (JSON)", lines=10, value=default_grammars["Default Grammar"])
 
 
 
 
 
 
 
 
 
 
 
486
 
487
  with gr.Row():
488
- regex_upload = gr.File(label="Carica file Regex_dict (JSON)", file_types=['.json'])
489
-
490
- regex_text = gr.Textbox(label="Inserisci regex_dict (JSON)", lines=10, value=default_regex_json)
491
 
492
- output_text = gr.Textbox(label="Output generato")
493
- zip_file = gr.File(label="Scarica ZIP")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
494
 
495
  # Callback: quando cambio dropdown, aggiorno productions_text
496
  grammar_choice.change(
@@ -506,20 +631,21 @@ with gr.Blocks() as demo:
506
  outputs=productions_text,
507
  )
508
 
509
- # Callback: quando carico file regex, aggiorno regex_text
510
- regex_upload.upload(
511
- fn=load_file,
512
- inputs=regex_upload,
513
- outputs=regex_text,
514
- )
515
-
516
  # Al submit del form chiamo run_grammarllm
517
- submit_btn = gr.Button("Genera output")
518
-
519
  submit_btn.click(
520
  fn=run_grammarllm,
521
- inputs=[prompt_input, productions_text, regex_text],
522
  outputs=[output_text, zip_file],
 
 
 
 
 
 
 
 
 
 
523
  )
524
 
525
  if __name__ == "__main__":
 
371
  raise RuntimeError(f"Errore nella generazione del testo: {e}")
372
 
373
 
374
+ import gradio as gr
375
+ import json
376
+ import re
377
+ import os
378
+ import zipfile
379
+ import spaces
380
+ from transformers import AutoTokenizer, AutoModelForCausalLM
381
+ import torch
382
+
383
+ # Assumendo che queste funzioni esistano nel tuo modulo
384
+ # from your_module import get_parsing_table_and_map_tt, generate_grammar_parameters, generate_text, setup_logging
385
 
386
# Placeholder only. The pre-commit version of run_grammarllm already called
# setup_logging(), so a real implementation is presumably defined or imported
# earlier in this module (TODO confirm). Guard the definition so this stub
# never overwrites it.
if "setup_logging" not in globals():
    def setup_logging():
        """No-op stand-in used when no real logging setup is in scope.

        Returns:
            None. Nothing is configured.
        """
        return None
389
+
390
# Placeholder only. The pre-commit code already called this name, so a real
# implementation is presumably in scope earlier in the module (TODO confirm).
# Guard the definition so the stub never replaces it.
if "get_parsing_table_and_map_tt" not in globals():
    def get_parsing_table_and_map_tt(tokenizer, productions, regex_dict):
        """Stand-in for the grammar parsing-table builder.

        The caller unpacks the result into two values
        (``pars_table, map_terminal_tokens``), so silently returning ``None``
        would only surface later as an opaque unpacking ``TypeError``.
        Fail loudly instead.

        Args:
            tokenizer: Tokenizer passed by the caller (unused here).
            productions: Grammar productions (unused here).
            regex_dict: Mapping of names to compiled regexes (unused here).

        Raises:
            NotImplementedError: Always, because no real implementation
                is available.
        """
        raise NotImplementedError(
            "get_parsing_table_and_map_tt is a placeholder with no implementation"
        )
393
+
394
# Placeholder only. The pre-commit code already called this name, so a real
# implementation is presumably in scope earlier in the module (TODO confirm).
# Guard the definition so the stub never replaces it.
if "generate_grammar_parameters" not in globals():
    def generate_grammar_parameters(tokenizer, pars_table, map_terminal_tokens):
        """Stand-in for the logit-processor / streamer factory.

        The caller unpacks the result into two values
        (``LogitProcessor, Streamer``), so returning ``None`` would fail
        with an unrelated-looking ``TypeError``. Raise a clear error instead.

        Args:
            tokenizer: Tokenizer passed by the caller (unused here).
            pars_table: Parsing table (unused here).
            map_terminal_tokens: Terminal-to-token mapping (unused here).

        Raises:
            NotImplementedError: Always, because no real implementation
                is available.
        """
        raise NotImplementedError(
            "generate_grammar_parameters is a placeholder with no implementation"
        )
397
+
398
# Placeholder only. A real generate_text is defined earlier in this file, so
# an unconditional re-definition here would replace working generation code
# with a function that returns None. Only install the stub when the name is
# not already bound.
if "generate_text" not in globals():
    def generate_text(model, tokenizer, prompt, LogitProcessor, Streamer):
        """Stand-in for grammar-constrained text generation.

        Args:
            model: Language model passed by the caller (unused here).
            tokenizer: Tokenizer passed by the caller (unused here).
            prompt: Prompt text (unused here).
            LogitProcessor: Logit processor object (unused here).
            Streamer: Streamer object (unused here).

        Raises:
            NotImplementedError: Always, because no real implementation
                is available.
        """
        raise NotImplementedError(
            "generate_text is a placeholder with no implementation"
        )
401
+
402
+ @spaces.GPU
403
+ def run_grammarllm(prompt, productions_json, model_choice):
404
  setup_logging()
405
 
406
  # Parsing productions
 
409
  except json.JSONDecodeError:
410
  return "Errore: JSON productions non valido.", None
411
 
412
+ # Regex fissa, non caricata dall'utente
413
+ regex_raw = {
414
+ "regex_alfanum": "[a-zA-Z0-9]+",
415
+ "regex_letters": "[a-zA-Z]+",
416
+ "regex_number": "\\d+",
417
+ "regex_decimal": "\\d+([.,]\\d+)?",
418
+ "regex_var": "[a-zA-Z_][a-zA-Z0-9_]*",
419
+ "regex_)": "\\)",
420
+ "regex_(": "\\("
421
+ }
422
+
423
  try:
 
424
  regex_dict = {key: re.compile(pattern) for key, pattern in regex_raw.items()}
 
 
425
  except re.error as e:
426
  return f"Errore nella compilazione regex: {str(e)}", None
427
 
428
  try:
429
+ # Selezione del modello basata sulla scelta dell'utente
430
+ if model_choice == "GPT-2":
431
+ model_name = "gpt2"
432
+ elif model_choice == "Llama 3.2 3B":
433
+ model_name = "meta-llama/Llama-3.2-3B"
434
+ elif model_choice == "Llama 3.2 1B":
435
+ model_name = "meta-llama/Llama-3.2-1B"
436
+ else:
437
+ return f"Modello non supportato: {model_choice}", None
438
+
439
+ # Caricamento del tokenizer e del modello
440
+ print(f"Caricamento del modello: {model_name}")
441
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
442
+
443
+ # Configurazione del device e dtype per ottimizzare le prestazioni
444
+ device = "cuda" if torch.cuda.is_available() else "cpu"
445
+
446
+ if model_choice.startswith("Llama"):
447
+ # Per i modelli Llama, usa torch_dtype=torch.float16 per risparmiare memoria
448
+ model = AutoModelForCausalLM.from_pretrained(
449
+ model_name,
450
+ torch_dtype=torch.float16,
451
+ device_map="auto",
452
+ trust_remote_code=True
453
+ )
454
+ else:
455
+ # Per GPT-2
456
+ model = AutoModelForCausalLM.from_pretrained(model_name)
457
+ model = model.to(device)
458
+
459
+ # Aggiungi pad_token se non esiste
460
+ if tokenizer.pad_token is None:
461
+ tokenizer.pad_token = tokenizer.eos_token
462
 
463
  pars_table, map_terminal_tokens = get_parsing_table_and_map_tt(
464
+ tokenizer,
465
+ productions=productions,
466
  regex_dict=regex_dict,
467
  )
468
 
469
  LogitProcessor, Streamer = generate_grammar_parameters(tokenizer, pars_table, map_terminal_tokens)
470
  output = generate_text(model, tokenizer, prompt, LogitProcessor, Streamer)
471
 
472
+ # Creazione del file ZIP
473
  temp_dir = "./temp"
474
  zip_path = temp_dir + ".zip"
475
 
476
+ # Assicurati che temp_dir esista
477
+ if os.path.exists(temp_dir):
478
+ with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
479
+ for root, dirs, files in os.walk(temp_dir):
480
+ for file in files:
481
+ file_path = os.path.join(root, file)
482
+ arcname = os.path.relpath(file_path, temp_dir)
483
+ zipf.write(file_path, arcname)
484
+ else:
485
+ zip_path = None
486
+
487
+ # Libera la memoria del modello
488
+ del model
489
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
490
 
491
  return output, zip_path
492
 
493
  except Exception as e:
494
  return f"Errore durante l'inferenza: {str(e)}", None
495
 
496
+
497
  default_grammars = {
498
  "Default Grammar": json.dumps({
499
+ "S*": ["<<positive>> A", "<<negative>> B", "<<neutral>> C"],
500
+ "A": ["<<happy>> D", "<<peaceful>> E", "<<joyful>> F"],
501
+ "B": ["<<sad>>", "<<angry>>", "<<frustrated>>"],
502
+ "C": ["<<calm>>", "<<indifferent>>", "<<unemotional>>"],
503
+ "D": ["<<enthusiastic>>"],
504
+ "E": ["<<content>>"],
505
+ "F": ["<<excited>>"]
506
  }, indent=4),
507
 
508
  "Other example": json.dumps({
509
+ 'S*': ["<<(>> A B", "<<negligent>> V", '<<indifferent>>'],
510
+ 'A': ["number", "letters", "ε"],
511
+ 'B': ['<<)>> letters R'],
512
+ 'R': ['C', 'D'],
513
+ 'C': ['<<calm>>', '<<indifferent>>', '<<unemotional>>'],
514
+ 'D': ['<<angry>>', '<<frustrated>>'],
515
+ 'V': ["<<option>>"],
516
  }, indent=4),
517
  }
518
 
 
 
 
 
 
 
 
 
 
 
519
 
520
  def update_productions(grammar_choice):
521
  # Aggiorna textbox productions al cambio preset
 
524
 
525
def load_file(file_obj):
    """Return the text of an uploaded JSON file, or an error message.

    Accepts the shapes an upload widget may hand over: a filesystem path
    (``str`` or ``os.PathLike``), an object exposing the path through a
    ``name`` attribute, or a readable stream yielding bytes or text.

    Args:
        file_obj: The uploaded file in any of the forms above, or ``None``.

    Returns:
        The file content as a string when it is valid JSON; otherwise a
        human-readable error string (errors are returned, never raised).
    """
    if file_obj is None:
        return "Errore: nessun file caricato."
    try:
        if isinstance(file_obj, (str, os.PathLike)):
            # Use path-like objects as-is: Path.name is only the basename
            # and would point at the wrong file.
            source_path = file_obj
        else:
            source_path = getattr(file_obj, "name", None)

        if isinstance(source_path, (str, os.PathLike)):
            with open(source_path, "r", encoding="utf-8") as handle:
                content = handle.read()
        else:
            # Stream fallback: binary streams need decoding, text streams
            # already yield str.
            raw = file_obj.read()
            if isinstance(raw, (bytes, bytearray)):
                content = raw.decode("utf-8")
            else:
                content = raw

        json.loads(content)  # validation only; the raw text is returned
        return content
    except Exception as e:
        # UI boundary: report every failure as text instead of raising.
        return f"Errore nel caricamento file: {str(e)}"
548
 
549
 
550
+ # Interfaccia Gradio migliorata
551
+ with gr.Blocks(title="GrammarLLM - Inferenza Guidata da Grammatica") as demo:
552
+
553
+ gr.Markdown("# GrammarLLM - Generazione di Testo Guidata da Grammatica")
554
+ gr.Markdown("Genera testo strutturato utilizzando grammatiche personalizzate con supporto per GPT-2 e modelli Llama.")
555
+
556
+ with gr.Row():
557
+ with gr.Column(scale=2):
558
+ prompt_input = gr.Textbox(
559
+ label="Inserisci prompt testuale",
560
+ placeholder="Scrivi qui il tuo prompt...",
561
+ lines=3
562
+ )
563
+
564
+ with gr.Column(scale=1):
565
+ model_choice = gr.Dropdown(
566
+ choices=["GPT-2", "Llama 3.2 1B", "Llama 3.2 3B"],
567
+ label="Scegli Modello",
568
+ value="GPT-2",
569
+ interactive=True
570
+ )
571
 
572
  with gr.Row():
573
+ with gr.Column():
574
+ grammar_choice = gr.Dropdown(
575
+ list(default_grammars.keys()),
576
+ label="Scegli Productions (JSON)",
577
+ value="Default Grammar",
578
+ interactive=True,
579
+ elem_id="grammar_choice"
580
+ )
581
+
582
+ with gr.Column():
583
+ productions_upload = gr.File(
584
+ label="Carica file Productions (JSON)",
585
+ file_types=['.json']
586
+ )
587
+
588
+ productions_text = gr.Textbox(
589
+ label="Productions (JSON)",
590
+ lines=15,
591
+ value=default_grammars["Default Grammar"],
592
+ info="Modifica direttamente la grammatica in formato JSON"
593
+ )
594
 
595
  with gr.Row():
596
+ submit_btn = gr.Button("🚀 Genera Output", variant="primary", size="lg")
597
+ clear_btn = gr.Button("🗑️ Pulisci", variant="secondary")
 
598
 
599
+ with gr.Row():
600
+ with gr.Column():
601
+ output_text = gr.Textbox(
602
+ label="Output generato",
603
+ lines=10,
604
+ show_copy_button=True
605
+ )
606
+
607
+ with gr.Column():
608
+ zip_file = gr.File(label="📦 Scarica ZIP (se disponibile)")
609
+
610
+ # Informazioni sui modelli
611
+ with gr.Accordion("ℹ️ Informazioni sui Modelli", open=False):
612
+ gr.Markdown("""
613
+ - **GPT-2**: Modello classico, veloce e leggero
614
+ - **Llama 3.2 1B**: Modello più recente e performante, dimensione ridotta
615
+ - **Llama 3.2 3B**: Modello più grande e capace, richiede più risorse
616
+
617
+ *Nota: I modelli Llama utilizzano Zero GPU per l'accelerazione automatica.*
618
+ """)
619
 
620
  # Callback: quando cambio dropdown, aggiorno productions_text
621
  grammar_choice.change(
 
631
  outputs=productions_text,
632
  )
633
 
 
 
 
 
 
 
 
634
  # Al submit del form chiamo run_grammarllm
 
 
635
  submit_btn.click(
636
  fn=run_grammarllm,
637
+ inputs=[prompt_input, productions_text, model_choice],
638
  outputs=[output_text, zip_file],
639
+ show_progress=True
640
+ )
641
+
642
+ # Reset callback for the clear button: returns, in order, an empty prompt, the default grammar text, and None for each of the two outputs.
643
+ def clear_fields():
644
+ return "", default_grammars["Default Grammar"], None, None
645
+
646
+ clear_btn.click(
647
+ fn=clear_fields,
648
+ outputs=[prompt_input, productions_text, output_text, zip_file]
649
  )
650
 
651
  if __name__ == "__main__":
requirements.txt CHANGED
@@ -3,4 +3,5 @@ tqdm
3
  transformers
4
  setuptools
5
  accelerate>=0.26.0
6
- gradio
 
 
3
  transformers
4
  setuptools
5
  accelerate>=0.26.0
6
+ gradio
7
+ spaces