"""Generate text with a fine-tuned mT5 model restored from a local checkpoint.

Restores the most recently saved checkpoint from disk (falling back to the
stock "google/mt5-small" weights when none exists) and runs beam-search
generation on a sample input.
"""

import os

import pandas as pd
import tensorflow as tf
from datasets import Dataset
from tqdm import tqdm
from transformers import TFMT5ForConditionalGeneration, MT5Tokenizer

# The tokenizer is always the stock mT5 vocabulary; only the model weights
# are restored from a checkpoint (see load_model below).
tokenizer = MT5Tokenizer.from_pretrained("google/mt5-small")
def load_model(checkpoint_dir):
    """Load the newest checkpoint under *checkpoint_dir*, or the stock model.

    Every sub-directory of *checkpoint_dir* is treated as a saved checkpoint;
    the most recently modified one is loaded. When the directory is missing
    or contains no sub-directories, the pretrained "google/mt5-small"
    weights are loaded instead.
    """
    latest_checkpoint = None
    if os.path.exists(checkpoint_dir):
        entries = (os.path.join(checkpoint_dir, name)
                   for name in os.listdir(checkpoint_dir))
        checkpoints = [path for path in entries if os.path.isdir(path)]
        if checkpoints:
            # Newest by modification time wins.
            latest_checkpoint = max(checkpoints, key=os.path.getmtime)

    if latest_checkpoint:
        print("Loading model from:", latest_checkpoint)
        return TFMT5ForConditionalGeneration.from_pretrained(latest_checkpoint)
    else:
        print("No checkpoint found, loading default model")
        return TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
# Restore the fine-tuned model; falls back to the stock weights when the
# checkpoint directory is absent or empty.
model = load_model('model_checkpoints-small-on-1mill-dp')
def prepare_text(text, tokenizer, max_length=200):
    """Tokenize *text* into a TensorFlow tensor of input ids.

    Inputs longer than *max_length* tokens are truncated.
    """
    return tokenizer.encode(
        text,
        return_tensors="tf",
        max_length=max_length,
        truncation=True,
    )
def generate_prediction(text, model, tokenizer, max_length=2000, num_beams=5):
    """Run beam-search generation on *text* and return the decoded string.

    *max_length* caps both the tokenized input and the generated output.
    A no-repeat bigram constraint curbs degenerate repetition, and beams
    may stop early once every candidate is finished.
    """
    input_ids = prepare_text(text, tokenizer, max_length=max_length)
    output_ids = model.generate(
        input_ids,
        max_length=max_length,
        num_beams=num_beams,
        no_repeat_ngram_size=2,
        early_stopping=True,
    )
    # output_ids[0] is the best beam for the single input sequence.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
# Smoke test: run the full pipeline end-to-end on one hard-coded example.
sample_text = "Hi , how are you?"

prediction = generate_prediction(sample_text, model, tokenizer)
print("Prediction:", prediction)
# NOTE(review): an unresolved merge conflict duplicated the entire script
# here (the origin/main side was functionally identical to the copy above,
# with extra table-pipe garbling). The duplicate copy and the conflict
# markers have been removed.