File size: 2,487 Bytes
c2ea27d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import torch
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
pipeline
)
import logging
# Module-level logger named after this module; used by AcoliModel.load_model
# to report load success and failure.
logger = logging.getLogger(__name__)
class AcoliModel:
    """Wrapper around a sequence-classification model and its pipeline.

    Attributes are all ``None`` until a model has been loaded:

    * ``model_path`` -- path/identifier of the most recently loaded model
      (or the value passed to the constructor).
    * ``tokenizer`` -- result of ``AutoTokenizer.from_pretrained``.
    * ``model`` -- result of
      ``AutoModelForSequenceClassification.from_pretrained``.
    * ``classifier`` -- ``"text-classification"`` pipeline built from the two.
    """

    def __init__(self, model_path=None):
        """Create the wrapper, loading the model right away if a path is given.

        Args:
            model_path: Location passed to ``load_model``. When falsy
                (``None`` or an empty string) nothing is loaded and
                ``load_model`` must be called before predicting.

        Raises:
            Exception: Whatever ``load_model`` raises when a path is given.
        """
        self.model_path = model_path
        self.tokenizer = None
        self.model = None
        self.classifier = None
        if model_path:
            self.load_model(model_path)

    def load_model(self, model_path):
        """Load a trained model, its tokenizer and the classification pipeline.

        The three objects are built into locals first and only stored on the
        instance once all of them succeeded, so a failed load never leaves a
        new tokenizer paired with an old (or missing) model.

        Args:
            model_path: Location handed to both ``from_pretrained`` calls.

        Raises:
            Exception: Any error raised while loading is logged (with
                traceback) and re-raised unchanged.
        """
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForSequenceClassification.from_pretrained(model_path)
            classifier = pipeline(
                "text-classification",
                model=model,
                tokenizer=tokenizer,
            )
        except Exception as e:
            logger.exception("Error loading model: %s", e)
            raise

        self.tokenizer = tokenizer
        self.model = model
        self.classifier = classifier
        # Keep the recorded path in sync with what is actually loaded so
        # get_model_info() does not report a stale or None path.
        self.model_path = model_path
        logger.info("Model loaded successfully from %s", model_path)

    def _require_classifier(self):
        """Return the pipeline, or raise ValueError if no model is loaded."""
        if self.classifier is None:
            raise ValueError("Model not loaded. Call load_model() first.")
        return self.classifier

    def predict(self, text):
        """Make a prediction on the input text.

        Args:
            text: Input forwarded unchanged to the classification pipeline.

        Returns:
            Whatever the pipeline returns for ``text``.

        Raises:
            ValueError: If no model has been loaded.
        """
        return self._require_classifier()(text)

    def predict_batch(self, texts):
        """Make predictions on multiple texts, one pipeline call per text.

        Args:
            texts: Iterable of inputs. A single ``str`` is treated as a
                one-element batch instead of being iterated character by
                character.

        Returns:
            A list with the pipeline's result for each text, in input order.

        Raises:
            ValueError: If no model has been loaded.
        """
        classifier = self._require_classifier()
        if isinstance(texts, str):
            texts = [texts]
        return [classifier(text) for text in texts]

    def get_model_info(self):
        """Get basic information about the loaded model.

        Returns:
            The string ``"Model not loaded"`` when no model is present,
            otherwise a dict with ``model_type`` (class name of the model),
            ``num_parameters`` (sum of ``numel()`` over the model's
            parameters) and ``model_path``.
        """
        if self.model is None:
            return "Model not loaded"
        return {
            "model_type": type(self.model).__name__,
            "num_parameters": sum(p.numel() for p in self.model.parameters()),
            "model_path": self.model_path,
        }
# Example usage and convenience functions
def load_acoli_model(model_path="./acoli-model"):
    """Build an AcoliModel for ``model_path`` and return it.

    The path is handed to the AcoliModel constructor, which loads the model
    immediately when the path is truthy.
    """
    acoli_model = AcoliModel(model_path=model_path)
    return acoli_model
def create_training_instance():
    """Return a new AcoliTrainer built with no arguments.

    The import happens inside the function, so the ``Train`` module is only
    needed when a trainer is actually requested.
    """
    from Train import AcoliTrainer

    trainer = AcoliTrainer()
    return trainer
if __name__ == "__main__":
    # Example usage: build an empty wrapper. No model is loaded until
    # load_model() is called (or a path is passed to the constructor).
    model = AcoliModel()
    # After training, you can load the model like this:
    # model.load_model("./acoli-model")
    # prediction = model.predict("Your Acoli text here")
    # print(prediction)
    print("Acoli model class ready for use!")