# Uploaded by Omarito101
# Commit message: "Upload trained Energy Document Classifier"
# Commit d54d478 (verified)
"""
Example: Using the Energy Document Classifier
This script demonstrates how to use the model for classifying documents.
"""
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
def _classify_text(tokenizer, model, text, max_length=512):
    """Return the class probabilities the model assigns to one text.

    Args:
        tokenizer: Callable tokenizer returning a mapping of PyTorch tensors.
        model: Sequence-classification model exposing ``.device`` and
            producing an output object with a ``.logits`` attribute.
        text: The document text to classify.
        max_length: Token limit; longer inputs are truncated to this length.

    Returns:
        A 1-D float32 tensor with one probability per class.
    """
    # Only one sequence is encoded at a time, so no padding is requested.
    # This keeps the example usable with tokenizers that define no pad token.
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    encoded = {key: tensor.to(model.device) for key, tensor in encoded.items()}

    with torch.no_grad():
        logits = model(**encoded).logits

    # The model runs in bfloat16, which has too few significant digits for
    # the four-decimal probabilities printed by the caller; upcast first.
    return torch.nn.functional.softmax(logits.float(), dim=-1)[0]


def main(model_name="EnergyAI/Llama-3.1-8B-Energy-Classifier"):
    """Load the classifier and print predictions for a few sample texts.

    Args:
        model_name: Hub id or local path of the sequence-classification
            checkpoint to load. Change it to point at your own model.
    """
    print("Loading model...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model.eval()
    print("Model loaded!\n")

    # Example texts
    texts = [
        "Solar panel installations have increased by 40% this year, driven by government incentives and falling prices.",
        "The software development team completed the sprint planning session and assigned tasks for the next iteration.",
        "OPEC announced a production cut of 2 million barrels per day, causing oil prices to surge on global markets.",
        "The training program for new employees will begin next Monday and continue for three weeks.",
    ]

    label_map = {0: "non_energy", 1: "energy"}
    heavy_rule = "=" * 70
    light_rule = "-" * 70

    for i, text in enumerate(texts, 1):
        print(f"\n{heavy_rule}")
        print(f"Example {i}:")
        print(f"Text: {text}")
        print(light_rule)

        probs = _classify_text(tokenizer, model, text)
        predicted_class = int(torch.argmax(probs))
        confidence = probs[predicted_class].item()

        # Print results
        print(f"Prediction: {label_map[predicted_class].upper()}")
        print(f"Confidence: {confidence:.4f}")
        print("Probabilities:")
        print(f" - Non-Energy: {probs[0].item():.4f}")
        print(f" - Energy: {probs[1].item():.4f}")

    # Closing rule printed once, after the last example.
    print(f"\n{heavy_rule}\n")
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()