|
|
""" |
|
|
Example: Using the Energy Document Classifier |
|
|
|
|
|
This script demonstrates how to use the model for classifying documents. |
|
|
""" |
|
|
|
|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
|
import torch |
|
|
|
|
|
def main():
    """Classify a few sample texts with the energy document classifier.

    Loads the tokenizer and sequence-classification model identified by
    ``model_name``, runs each sample text through the model one at a time,
    and prints the predicted label, its confidence, and the probabilities of
    class 0 ("non_energy") and class 1 ("energy").

    The function takes no arguments and returns ``None``; all results are
    written to standard output.
    """
    model_name = "EnergyAI/Llama-3.1-8B-Energy-Classifier"

    print("Loading model...")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    # Switch to inference mode (disables dropout and similar train-only layers).
    model.eval()
    print("Model loaded!\n")

    texts = [
        "Solar panel installations have increased by 40% this year, driven by government incentives and falling prices.",
        "The software development team completed the sprint planning session and assigned tasks for the next iteration.",
        "OPEC announced a production cut of 2 million barrels per day, causing oil prices to surge on global markets.",
        "The training program for new employees will begin next Monday and continue for three weeks.",
    ]

    # Maps the model's output index to a human-readable label.
    label_map = {0: "non_energy", 1: "energy"}

    # Separator lines are loop-invariant, so build them once.
    heavy_rule = "=" * 70
    light_rule = "-" * 70

    for i, text in enumerate(texts, 1):
        print(f"\n{heavy_rule}")
        print(f"Example {i}:")
        print(f"Text: {text}")
        print(light_rule)

        # Each text is encoded on its own, so no padding is requested: for a
        # single sequence it would be a no-op, and asking for it raises a
        # ValueError when the tokenizer defines no pad token.
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        )
        # Move every input tensor to the device the model lives on.
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            # The model runs in bfloat16; compute the softmax in float32 so
            # the probabilities printed to four decimals are not quantized
            # to bfloat16 precision.
            probs = torch.nn.functional.softmax(
                outputs.logits, dim=-1, dtype=torch.float32
            )
            predicted_class = torch.argmax(probs, dim=-1).item()
            confidence = probs[0][predicted_class].item()

        print(f"Prediction: {label_map[predicted_class].upper()}")
        print(f"Confidence: {confidence:.4f}")
        print("Probabilities:")
        print(f" - Non-Energy: {probs[0][0].item():.4f}")
        print(f" - Energy: {probs[0][1].item():.4f}")

    print(f"\n{heavy_rule}\n")
|
|
|
|
|
# Run the demo only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
|
|
|