import jax
import jax.numpy as jnp
import pandas as pd

from datasets import Dataset, load_dataset
from flax import jax_utils
from flax.training.common_utils import shard
from tqdm import tqdm
from transformers import AutoTokenizer, FlaxAutoModelForSeq2SeqLM


print(jax.devices())
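# Load the tokenizer and the Flax seq2seq model, then define the inference settings
# (task prefix, dataset column names, sequence lengths, seed, and batch size).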
MODEL_NAME_OR_PATH = "../"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH, use_fast=True)
model = FlaxAutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)

prefix = "items: "
text_column = "inputs"
target_column = "targets"
max_source_length = 256
max_target_length = 1024
seed = 42
eval_batch_size = 64
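# Beam-search decoding settings passed to model.generate for every batch.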
generation_kwargs = {
    "max_length": 1024,
    "min_length": 64,
    "no_repeat_ngram_size": 3,
    "early_stopping": True,
    "num_beams": 4,
    "length_penalty": 1.5,
}
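# Post-processing helpers: strip the tokenizer's special tokens and map the
# recipe-structure tokens (<sep>, <section>) back to plain-text separators.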
special_tokens = tokenizer.all_special_tokens
tokens_map = {
    "<sep>": "--",
    "<section>": "\n"
}


def skip_special_tokens(text, special_tokens):
    for token in special_tokens:
        text = text.replace(token, '')

    return text


def target_postprocessing(texts, special_tokens):
    if not isinstance(texts, list):
        texts = [texts]

    new_texts = []
    for text in texts:
        text = skip_special_tokens(text, special_tokens)

        for k, v in tokens_map.items():
            text = text.replace(k, v)

        new_texts.append(text)

    return new_texts
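# Load the prediction split from a tab-separated CSV file.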
predict_dataset = load_dataset(
    "csv",
    data_files={"test": "/home/m3hrdadfi/code/data/test.csv"},
    delimiter="\t",
)["test"]
print(predict_dataset)

column_names = predict_dataset.column_names
print(column_names)
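# Tokenize the source text (with the task prefix) and the targets to fixed lengths,
# keeping the tokenized targets as labels so the generations can be compared later.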
def preprocess_function(examples):
    inputs = examples[text_column]
    targets = examples[target_column]
    inputs = [prefix + inp for inp in inputs]
    model_inputs = tokenizer(
        inputs,
        max_length=max_source_length,
        padding="max_length",
        truncation=True,
        return_tensors="np"
    )

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            targets,
            max_length=max_target_length,
            padding="max_length",
            truncation=True,
            return_tensors="np"
        )

    model_inputs["labels"] = labels["input_ids"]

    return model_inputs


predict_dataset = predict_dataset.map(
    preprocess_function,
    batched=True,
    num_proc=None,
    remove_columns=column_names,
    desc="Running tokenizer on prediction dataset",
)
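# Yield device-sharded batches from the tokenized dataset, dropping the last
# incomplete batch.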
def data_loader(rng: jax.random.PRNGKey, dataset: Dataset, batch_size: int, shuffle: bool = False):
    """
    Returns batches of size `batch_size` from `dataset` (truncated to a whole number of
    batches), sharded over all local devices. Shuffles the examples if `shuffle` is `True`.
    """
    steps_per_epoch = len(dataset) // batch_size

    if shuffle:
        batch_idx = jax.random.permutation(rng, len(dataset))
    else:
        batch_idx = jnp.arange(len(dataset))

    batch_idx = batch_idx[: steps_per_epoch * batch_size]
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

    for idx in batch_idx:
        batch = dataset[idx]
        batch = {k: jnp.array(v) for k, v in batch.items()}

        batch = shard(batch)

        yield batch
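# Seed the PRNG and set up a pmapped generation step that runs beam search on
# every local device in parallel.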
rng = jax.random.PRNGKey(seed)
rng, input_rng = jax.random.split(rng)

# Replicate the model parameters once so each local device holds its own copy.
params = jax_utils.replicate(model.params)


def generate_step(params, batch):
    output_ids = model.generate(
        batch["input_ids"],
        attention_mask=batch["attention_mask"],
        params=params,
        **generation_kwargs,
    )
    return output_ids.sequences


p_generate_step = jax.pmap(generate_step, "batch")
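# Run prediction over the test set, collecting input ids, label ids, and
# generated ids from all devices.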
pred_generations = []
pred_labels = []
pred_inputs = []
pred_loader = data_loader(input_rng, predict_dataset, eval_batch_size)
pred_steps = len(predict_dataset) // eval_batch_size

for _ in tqdm(range(pred_steps), desc="Predicting...", position=2, leave=False):
    batch = next(pred_loader)
    inputs = batch["input_ids"]
    labels = batch["labels"]

    generated_ids = p_generate_step(params, batch)
    pred_generations.extend(jax.device_get(generated_ids.reshape(-1, generation_kwargs["max_length"])))
    pred_labels.extend(jax.device_get(labels.reshape(-1, labels.shape[-1])))
    pred_inputs.extend(jax.device_get(inputs.reshape(-1, inputs.shape[-1])))
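# Decode ids back to text, restore the recipe separators, and write the results
# to a tab-separated file.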
inputs = tokenizer.batch_decode(pred_inputs, skip_special_tokens=True)
true_recipe = target_postprocessing(
    tokenizer.batch_decode(pred_labels, skip_special_tokens=False),
    special_tokens
)
generated_recipe = target_postprocessing(
    tokenizer.batch_decode(pred_generations, skip_special_tokens=False),
    special_tokens
)
test_output = {
    "inputs": inputs,
    "true_recipe": true_recipe,
    "generated_recipe": generated_recipe
}
test_output = pd.DataFrame.from_dict(test_output)
test_output.to_csv("./generated_recipes_b.csv", sep="\t", index=False, encoding="utf-8")