| | import logging |
| | import os |
| | import pandas as pd |
| | import random |
| | import re |
| | import sys |
| | import time |
| | from dataclasses import dataclass, field |
| | from functools import partial |
| | from pathlib import Path |
| | from typing import Callable, Optional |
| |
|
| | import jax |
| | import jax.numpy as jnp |
| |
|
| | from filelock import FileLock |
| | from flax import jax_utils, traverse_util |
| | from flax.jax_utils import unreplicate |
| | from flax.training import train_state |
| | from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key |
| |
|
| | from transformers import FlaxAutoModelForSeq2SeqLM |
| | from transformers import AutoTokenizer |
| |
|
| | from datasets import Dataset, load_dataset, load_metric |
| | from tqdm import tqdm |
| | import pandas as pd |
| |
|
| |
|
# Sanity check: list the accelerators JAX can see before loading the model.
print(jax.devices())

# Path to the pretrained seq2seq checkpoint (relative: the parent directory).
MODEL_NAME_OR_PATH = "../"

# Load tokenizer and Flax seq2seq model weights from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)
| |
|
# Task prompt prepended to every input (T5-style text-to-text conditioning).
prefix = "items: "
text_column = "inputs"      # source-text column name in the CSV
target_column = "targets"   # reference-target column name in the CSV
max_source_length = 256     # tokens; longer inputs are truncated
max_target_length = 1024    # tokens; longer targets are truncated
seed = 42
eval_batch_size = 64        # global batch size, split across local devices

# Beam-search decoding settings, passed straight to `model.generate`.
generation_kwargs = {
    "max_length": 1024,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
    "early_stopping": True,
    "num_beams": 4,
    "length_penalty": 1.5,
}
| |
|
# Tokenizer markers to strip from decoded text, plus layout markers mapped to
# plain-text equivalents for the output file.
special_tokens = tokenizer.all_special_tokens
tokens_map = {
    "<sep>": "--",
    "<section>": "\n"
}
def skip_special_tokens(text, special_tokens):
    """Return `text` with every marker in `special_tokens` removed."""
    cleaned = text
    for marker in special_tokens:
        cleaned = cleaned.replace(marker, "")
    return cleaned
| |
|
def target_postprocessing(texts, special_tokens):
    """Clean decoded model outputs for the final CSV.

    Strips every marker in `special_tokens`, then maps the layout markers in
    the module-level `tokens_map` (e.g. "<sep>", "<section>") to their
    plain-text replacements. Accepts a single string or a list of strings and
    always returns a list.
    """
    if not isinstance(texts, list):
        texts = [texts]

    cleaned = []
    for raw in texts:
        text = skip_special_tokens(raw, special_tokens)
        for marker, replacement in tokens_map.items():
            text = text.replace(marker, replacement)
        cleaned.append(text)

    return cleaned
| |
|
| |
|
# Load the tab-separated test split.
# NOTE(review): the data path is machine-specific — parameterize before
# running this script anywhere else.
predict_dataset = load_dataset("csv", data_files={"test": "/home/m3hrdadfi/code/data/test.csv"}, delimiter="\t")["test"]
print(predict_dataset)

# Capture original column names so they can be dropped after tokenization.
column_names = predict_dataset.column_names
print(column_names)
| |
|
| |
|
| | |
def preprocess_function(examples):
    """Tokenize one batch of examples for seq2seq prediction.

    Prepends the task `prefix` to each source text, tokenizes sources and
    targets to fixed-length NumPy arrays (padding/truncating to
    `max_source_length` / `max_target_length`), and attaches the target token
    ids under the "labels" key.
    """
    source_texts = [prefix + text for text in examples[text_column]]
    model_inputs = tokenizer(
        source_texts,
        max_length=max_source_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )

    # Targets must be tokenized in target mode (matters for seq2seq tokenizers
    # with distinct source/target vocabularies or processing).
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            examples[target_column],
            max_length=max_target_length,
            padding="max_length",
            truncation=True,
            return_tensors="np",
        )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
| |
|
# Tokenize the whole test split up front; the raw text columns are removed so
# the dataset only carries model-ready arrays (input_ids, attention_mask, labels).
predict_dataset = predict_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=None,
    remove_columns=column_names,
    desc="Running tokenizer on prediction dataset",
)
| |
|
def data_loader(rng: jnp.ndarray, dataset: Dataset, batch_size: int, shuffle: bool = False):
    """
    Yield batches of size `batch_size` from `dataset`, sharded over all local devices.

    Trailing examples that do not fill a complete batch are dropped, so exactly
    ``len(dataset) // batch_size`` batches are produced. When `shuffle` is True,
    `rng` drives a full permutation of the example order.

    Args:
        rng: a JAX PRNG key array (as returned by ``jax.random.PRNGKey``);
            only consumed when `shuffle` is True. (Annotation fixed:
            ``jax.random.PRNGKey`` is a factory function, not a type.)
        dataset: an indexable dataset supporting ``len()`` and batched indexing.
        batch_size: global batch size; must be divisible by the local device
            count for ``shard`` to succeed.
        shuffle: whether to permute example order before batching.

    Yields:
        dict of arrays with a leading (num_devices, batch_size // num_devices)
        layout, ready for ``jax.pmap``.
    """
    steps_per_epoch = len(dataset) // batch_size

    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
    else:
        batch_idx = jnp.arange(len(dataset))

    # Drop the incomplete final batch, then view indices as (steps, batch_size).
    batch_idx = batch_idx[: steps_per_epoch * batch_size]
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: jnp.array(v) for k, v in batch.items()}

        # Reshape leading axis for per-device sharding under pmap.
        batch = shard(batch)

        yield batch
| |
|
# Derive independent PRNG streams from the fixed seed. `dropout_rng` is not
# used anywhere in this script; only `input_rng` feeds the data loader below.
rng = jax.random.PRNGKey(seed)
rng, dropout_rng = jax.random.split(rng)
rng, input_rng = jax.random.split(rng)
| |
|
def generate_step(batch):
    """Run generation for one per-device batch and return the token-id sequences."""
    outputs = model.generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        **generation_kwargs,
    )
    return outputs.sequences
| |
|
# Replicate the generation step across local devices; each device processes
# one shard of every batch.
p_generate_step = jax.pmap(generate_step, "batch")

pred_generations = []
pred_labels = []
pred_inputs = []
pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
# data_loader drops the incomplete final batch, so iterate exactly this many steps.
pred_steps = len(predict_dataset) // eval_batch_size
| |
|
for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
    batch = next(pred_loader)
    inputs = batch["input_ids"]
    labels = batch["labels"]

    # Generate on all devices, then flatten the sharded
    # (num_devices, per_device_batch, seq_len) output back to
    # (batch, seq_len) on the host before accumulating.
    generated_ids = p_generate_step(batch)
    pred_generations.extend(jax.device_get(generated_ids.reshape(-1, generation_kwargs["max_length"])))
    pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
    pred_inputs.extend(jax.device_get(inputs.reshape(-1, inputs.shape[-1])))

# Decode inputs with special tokens stripped by the tokenizer; decode labels
# and generations with special tokens kept so target_postprocessing can strip
# them itself and map the <sep>/<section> layout markers.
# NOTE(review): this assumes <sep>/<section> are NOT in all_special_tokens,
# otherwise they'd be stripped before the mapping runs — confirm.
inputs = tokenizer.batch_decode(pred_inputs, skip_special_tokens=True)
true_recipe = target_postprocessing(
    tokenizer.batch_decode(pred_labels, skip_special_tokens=False),
    special_tokens
)
generated_recipe = target_postprocessing(
    tokenizer.batch_decode(pred_generations, skip_special_tokens=False),
    special_tokens
)
test_output = {
    "inputs": inputs,
    "true_recipe": true_recipe,
    "generated_recipe": generated_recipe
}
test_output = pd.DataFrame.from_dict(test_output)
test_output.to_csv("./generated_recipes_b.csv", sep="\t", index=False, encoding="utf-8")
| |
|