import os
import re
import string

import contractions
import datasets
import evaluate
import pandas as pd
import torch
from datasets import Dataset
from tqdm import tqdm
from transformers import (AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer,
                          DataCollatorForSeq2Seq, Seq2SeqTrainer,
                          Seq2SeqTrainingArguments)


def clean_text(texts):
    """Normalize raw text for later use: lowercase, expand contractions,
    strip punctuation and collapse newlines into spaces."""
    texts = texts.lower()
    texts = contractions.fix(texts)
    texts = texts.translate(str.maketrans("", "", string.punctuation))
    texts = re.sub(r"\n", " ", texts)
    return texts
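
# Quick sanity check of the cleaning pipeline (note the trailing space left
# by the newline substitution):
#   clean_text("Don't stop!\n")  ->  "do not stop "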


def datasetmaker(path: str):
    """Read a jsonl file into a dataframe, drop the columns not needed
    for the task and convert the result into a datasets.Dataset."""
    data = pd.read_json(path, lines=True)
    df = data.drop(
        [
            "url",
            "archive",
            "title",
            "date",
            "compression",
            "coverage",
            "density",
            "compression_bin",
            "coverage_bin",
            "density_bin",
        ],
        axis=1,
    )
    tqdm.pandas()
    df["text"] = df.text.progress_apply(clean_text)
    df["summary"] = df.summary.progress_apply(clean_text)
    dataset = Dataset.from_pandas(df)
    return dataset
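
# The dropped metadata columns (compression/coverage/density and their bins)
# match the Cornell Newsroom dataset schema, which the *_extract.jsonl files
# are presumably derived from.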

# Quick check: see whether the pretrained model already performs well
# out of the box.
# test_text = dataset['text'][0]
# pipe = pipeline('summarization', model = model_ckpt)
# pipe_out = pipe(test_text)
# print(pipe_out[0]['summary_text'].replace('.<n>', '.\n'))
# print(dataset['summary'][0])


def generate_batch_sized_chunks(list_elements, batch_size):
    """Split the dataset into smaller batches that can be processed
    simultaneously: yield successive batch-sized chunks from list_elements."""
    for i in range(0, len(list_elements), batch_size):
        yield list_elements[i: i + batch_size]
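
# For example:
#   list(generate_batch_sized_chunks([1, 2, 3, 4, 5], 2))
#   == [[1, 2], [3, 4], [5]]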


def calculate_metric(dataset, metric, model, tokenizer,
                     batch_size, device,
                     column_text='text',
                     column_summary='summary'):
    """Evaluate the model with the ROUGE metric and return a table of the
    'rouge1', 'rouge2', 'rougeL' and 'rougeLsum' scores."""
    article_batches = list(
        generate_batch_sized_chunks(dataset[column_text], batch_size)
    )
    target_batches = list(
        generate_batch_sized_chunks(dataset[column_summary], batch_size)
    )
    for article_batch, target_batch in tqdm(
        zip(article_batches, target_batches), total=len(article_batches)
    ):
        inputs = tokenizer(
            article_batch,
            max_length=1024,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # The length penalty keeps the model from generating sequences
        # that are too long.
        summaries = model.generate(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device),
            length_penalty=0.8,
            num_beams=8,
            max_length=128,
        )
        # Decode the generated texts, replace the <n> token, and add the
        # decoded summaries together with the references to the metric.
        decoded_summaries = [
            tokenizer.decode(
                s, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            for s in summaries
        ]
        decoded_summaries = [d.replace("<n>", " ") for d in decoded_summaries]
        metric.add_batch(
            predictions=decoded_summaries,
            references=target_batch)
    # Compute and return the ROUGE scores.
    results = metric.compute()
    rouge_names = ["rouge1", "rouge2", "rougeL", "rougeLsum"]
    rouge_dict = dict((rn, results[rn]) for rn in rouge_names)
    return pd.DataFrame(rouge_dict, index=["mT5"])


def convert_ex_to_features(example_batch):
    """Tokenize a batch of examples into the input features expected by the
    model (input ids, attention mask and label ids)."""
    input_encodings = tokenizer(example_batch['text'],
                                max_length=1024, truncation=True)
    labels = tokenizer(
        example_batch["summary"],
        max_length=128,
        truncation=True)
    return {
        "input_ids": input_encodings["input_ids"],
        "attention_mask": input_encodings["attention_mask"],
        "labels": labels["input_ids"],
    }
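
# convert_ex_to_features relies on the module-level `tokenizer` created in
# the __main__ block below, so it can only be called after that point.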


if __name__ == '__main__':
    # Build the cleaned datasets.
    train_dataset = datasetmaker('data/train_extract.jsonl')
    dev_dataset = datasetmaker('data/dev_extract.jsonl')
    test_dataset = datasetmaker('data/test_extract.jsonl')

    dataset = datasets.DatasetDict({'train': train_dataset,
                                    'dev': dev_dataset, 'test': test_dataset})
    # Select the device.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the model to fine-tune. The Hugging Face token is read from the
    # environment rather than hard-coded in the source.
    hf_token = os.environ.get("HF_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained('google/mt5-small',
                                              use_auth_token=hf_token)
    mt5_config = AutoConfig.from_pretrained(
        "google/mt5-small",
        max_length=128,
        length_penalty=0.6,
        no_repeat_ngram_size=2,
        num_beams=15,
        use_auth_token=hf_token,
    )
    model = (AutoModelForSeq2SeqLM
             .from_pretrained('google/mt5-small', config=mt5_config)
             .to(device))
    # Convert the examples into input features.
    dataset_pt = dataset.map(
        convert_ex_to_features,
        remove_columns=["summary", "text"],
        batched=True,
        batch_size=128,
    )
    data_collator = DataCollatorForSeq2Seq(
        tokenizer, model=model, return_tensors="pt")

    # Define the training (fine-tuning) parameters.
    training_args = Seq2SeqTrainingArguments(
        output_dir="t5_summary",
        log_level="error",
        num_train_epochs=10,
        learning_rate=5e-4,
        warmup_steps=0,
        optim="adafactor",
        weight_decay=0.01,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=16,
        evaluation_strategy="steps",
        eval_steps=100,
        predict_with_generate=True,
        generation_max_length=128,
        save_steps=500,
        logging_steps=10,
        # push_to_hub=True,
    )
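    # Effective train batch size: per_device_train_batch_size (2)
    # x gradient_accumulation_steps (16) = 32 examples per device.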

    # Give the trainer the model and everything it needs for training.
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        # compute_metrics=calculate_metric,
        train_dataset=dataset_pt["train"],
        eval_dataset=dataset_pt["dev"].select(range(10)),
        tokenizer=tokenizer,
    )
    trainer.train()

    rouge_metric = evaluate.load("rouge")
    # Then evaluate the fine-tuned model on the test set.
    score = calculate_metric(
        test_dataset,
        rouge_metric,
        trainer.model,
        tokenizer,
        batch_size=2,
        device=device,
        column_text="text",
        column_summary="summary",
    )
    print(score)

    # Fine-tuning is done; save the fine-tuned model locally.
    os.makedirs("t5_summary", exist_ok=True)
    if hasattr(trainer.model, "module"):
        trainer.model.module.save_pretrained("t5_summary")
    else:
        trainer.model.save_pretrained("t5_summary")
    tokenizer.save_pretrained("t5_summary")

    # Reload the fine-tuned model from the local directory.
    model = (AutoModelForSeq2SeqLM
             .from_pretrained("t5_summary")
             .to(device))

    # Try out the model: TEST
    # gen_kwargs = {"length_penalty": 0.8, "num_beams": 8, "max_length": 128}
    # sample_text = dataset["test"][0]["text"]
    # reference = dataset["test"][0]["summary"]
    # pipe = pipeline("summarization", model="t5_summary")
    # print("Text:")
    # print(sample_text)
    # print("\nReference Summary:")
    # print(reference)
    # print("\nModel Summary:")
    # print(pipe(sample_text, **gen_kwargs)[0]["summary_text"])