File size: 3,246 Bytes
d0c3c53 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
"""
Example script showing how to use the Hello World model with its dataset.
"""
from transformers import PreTrainedTokenizerFast
from model import HelloWorldModel, HelloWorldConfig
from datasets import load_dataset
import torch
def main():
    """Run an end-to-end demo of the Hello World model with its dataset.

    Loads the config, model, and tokenizer from the Hugging Face Hub, then
    walks through six numbered demonstrations: loading the dataset via the
    model's helper, loading it directly with the ``datasets`` library,
    preparing and tokenizing a batch, running inference, calling
    ``generate_hello_world``, and iterating the test split with per-example
    next-token prediction.
    """
    print("Loading Hello World Model and Dataset Example\n")
    print("=" * 50)

    # Load model and tokenizer from the Hub.
    print("Loading model and tokenizer...")
    # NOTE(review): config is loaded for demonstration only — nothing below
    # reads it. Kept so the example still shows the config-loading call.
    config = HelloWorldConfig.from_pretrained("chiedo/hello-world")
    model = HelloWorldModel.from_pretrained("chiedo/hello-world")
    tokenizer = PreTrainedTokenizerFast.from_pretrained("chiedo/hello-world")

    # Method 1: load the dataset using the model's built-in helper.
    print("\n1. Loading dataset using model's load_dataset method:")
    dataset = HelloWorldModel.load_dataset("chiedo/hello-world")
    if dataset:
        # Plain string here — the original used an f-string with no placeholders.
        print("Dataset loaded successfully!")
        print(f"Splits available: {list(dataset.keys())}")
        print(f"Train examples: {len(dataset['train'])}")
        print(f"Validation examples: {len(dataset['validation'])}")
        print(f"Test examples: {len(dataset['test'])}")

        # Show the first few training examples.
        print("\nFirst 3 training examples:")
        for i in range(min(3, len(dataset['train']))):
            example = dataset['train'][i]
            print(f"  {i+1}. Text: '{example['text']}', Label: {example['label']}")
    else:
        # The original fell through silently on failure; surface it so the
        # demo is debuggable.
        print("Dataset could not be loaded via the model helper.")

    # Method 2: load the dataset directly with the datasets library.
    print("\n2. Loading dataset directly with datasets library:")
    dataset_direct = load_dataset("chiedo/hello-world")

    # Map integer label ids back to their human-readable names.
    label_names = dataset_direct['train'].features['label'].names
    print(f"Label categories: {label_names}")

    # Process a small batch from the dataset.
    print("\n3. Processing a batch from the dataset:")
    batch_texts = dataset_direct['train']['text'][:3]
    print(f"Batch texts: {batch_texts}")

    # Prepare the batch for the model.
    inputs = model.prepare_dataset_batch(batch_texts, tokenizer)
    print(f"Tokenized input shape: {inputs['input_ids'].shape}")

    # Run model inference on the prepared batch.
    print("\n4. Running model inference on dataset batch:")
    with torch.no_grad():  # inference only — no gradients needed
        outputs = model(**inputs)
    print(f"Model output shape: {outputs.logits.shape}")

    # Demonstrate the generate_hello_world convenience function.
    print("\n5. Testing generate_hello_world function:")
    result = model.generate_hello_world()
    print(f"Generated output: {result}")

    # Iterate the test split, predicting the next token for each example.
    print("\n6. Iterating through test set:")
    for i, example in enumerate(dataset_direct['test']):
        if i >= 3:  # only show the first 3 examples
            break
        text = example['text']
        label_name = label_names[example['label']]

        # Tokenize and run a single example through the model.
        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)

        # Greedy next-token pick from the logits at the final position.
        predicted_token = outputs.logits[0, -1].argmax().item()
        print(f"  Text: '{text}' | Label: {label_name} | Predicted next token ID: {predicted_token}")

    print("\n" + "=" * 50)
    print("Example completed successfully!")
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()