Instructions to use samyak152002/Tweet_Abortion_Analysis with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use samyak152002/Tweet_Abortion_Analysis with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="samyak152002/Tweet_Abortion_Analysis")

# Load model directly
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("samyak152002/Tweet_Abortion_Analysis")
model = AutoModel.from_pretrained("samyak152002/Tweet_Abortion_Analysis")
- Notebooks
- Google Colab
- Kaggle
File size: 1,249 Bytes
import torch
from transformers import DistilBertModel, DistilBertTokenizer

# NOTE(review): ``DistilBertCNN`` is neither imported nor defined in this
# file, so constructing it below raises NameError. It is presumably a custom
# ``torch.nn.Module`` built on ``DistilBertModel`` (imported above but
# otherwise unused) with a classification head. TODO: add or import its
# definition before running this script.

# Load the tokenizer and model
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertCNN(num_labels=3)  # Assuming you have defined the custom classification layers

# Move the model to CPU
device = torch.device("cpu")
model.to(device)

# Load the saved model state dictionary.
# SECURITY NOTE(review): torch.load unpickles the file, and unpickling can
# execute arbitrary code. Only load checkpoints from a trusted source. On
# PyTorch versions that support it, consider ``weights_only=True``.
model.load_state_dict(torch.load("model.pt", map_location=device))

# Set the model to evaluation mode (affects layers such as dropout)
model.eval()
# Define a function to predict the class of a given tweet
def classify_tweet(tweet, *, max_length=128):
    """Return the predicted class index for a single tweet.

    Tokenizes ``tweet`` with the module-level ``tokenizer`` and moves the
    tensors to the module-level ``device``. It then runs the module-level
    ``model`` under ``torch.no_grad()`` and returns the index of the
    largest logit.

    Args:
        tweet: The raw text to classify.
        max_length: Token length the input is padded and truncated to.
            Defaults to 128, the value that was previously hard-coded.

    Returns:
        int: The predicted class index. The meaning of each index is not
        defined in this script.
    """
    # Calling the tokenizer directly is the recommended replacement for the
    # deprecated ``encode_plus``. For a single string it dispatches to the
    # same encoding path with the same arguments.
    inputs = tokenizer(
        tweet,
        add_special_tokens=True,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

    # ``outputs[0]`` is either the logits row for the single input (if the
    # model returns a bare tensor) or the whole logits tensor (if it returns
    # a tuple). TODO: confirm against DistilBertCNN.forward.
    # Taking argmax over the last (class) axis yields the same single index
    # in both cases, without relying on argmax flattening the tensor.
    logits = outputs[0]
    predicted_class = torch.argmax(logits, dim=-1).item()
    return predicted_class
# Example usage: classify a sample tweet when this file is run as a script.
# The __main__ guard keeps the demo from running (and printing) when the
# module is imported just to use ``classify_tweet``.
if __name__ == "__main__":
    tweet = "This is a sample tweet."
    predicted_class = classify_tweet(tweet)
    print(f"Predicted Class: {predicted_class}")