rwcuffney's picture
Update app.py
1ffb024
raw
history blame
1.13 kB
"""Streamlit app: run a fine-tuned sequence classifier over a HF dataset and
display the predicted class labels batch by batch."""
import streamlit as st
import torch  # was missing: torch.cuda/no_grad/argmax are used below
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_ID = 'rwcuffney/autotrain-pick_a_card-3726099224'
BATCH_SIZE = 32

# load_dataset() accepts no batch_size/shuffle kwargs (the original call raised);
# select a split explicitly and shuffle via the Dataset API instead.
# NOTE(review): assumes the dataset has a 'train' split and a 'text' column — confirm.
dataset = load_dataset('rwcuffney/pick_a_card_test', split='train').shuffle(seed=42)

model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


def preprocess_text(text):
    """Tokenize a string (or list of strings) into padded/truncated PyTorch tensors.

    Returns a BatchEncoding suitable for model(**encoded).
    """
    return tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt',
    )


device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()  # disable dropout etc. for deterministic inference

# Hoisted out of the loop: label names are invariant across batches.
label_names = dataset.features['label'].names

# Iterate in fixed-size batches. (Iterating the raw DatasetDict, as the
# original did, yields split-name strings — not rows.)
for start in range(0, len(dataset), BATCH_SIZE):
    batch = dataset[start:start + BATCH_SIZE]
    inputs = preprocess_text(batch['text']).to(device)
    with torch.no_grad():  # no gradients needed for inference
        outputs = model(**inputs)
    predicted_classes = torch.argmax(outputs.logits, dim=-1)
    predicted_labels = [label_names[i] for i in predicted_classes]
    st.write(predicted_labels)