# Acoli / model.py
# (Hugging Face repo header: uploaded by prelington, "Create model.py", commit c2ea27d verified)
import torch
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
pipeline
)
import logging
# Module-level logger using the standard per-module naming convention.
logger = logging.getLogger(__name__)
class AcoliModel:
    """Wrapper around a Hugging Face sequence-classification model for Acoli text.

    Holds the tokenizer, model, and a ``text-classification`` pipeline, and
    exposes simple single-text and batch prediction helpers.
    """

    def __init__(self, model_path=None):
        """Create the wrapper; if *model_path* is given, load the model immediately."""
        self.model_path = model_path  # path the model was (or will be) loaded from
        self.tokenizer = None
        self.model = None
        self.classifier = None  # text-classification pipeline, set by load_model()
        if model_path:
            self.load_model(model_path)

    def load_model(self, model_path):
        """Load a trained tokenizer/model pair and build the classification pipeline.

        Raises:
            Exception: re-raises whatever ``from_pretrained`` / ``pipeline``
                raised, after logging the error.
        """
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
            self.classifier = pipeline(
                "text-classification",
                model=self.model,
                tokenizer=self.tokenizer,
            )
            # Fix: record the path so get_model_info() stays accurate even when
            # load_model() is called directly (the original only set it in __init__).
            self.model_path = model_path
            logger.info("Model loaded successfully from %s", model_path)
        except Exception as e:
            logger.error("Error loading model: %s", e)
            raise

    def _require_classifier(self):
        """Raise ValueError if no model/pipeline has been loaded yet."""
        if self.classifier is None:
            raise ValueError("Model not loaded. Call load_model() first.")

    def predict(self, text):
        """Make prediction on input text."""
        self._require_classifier()
        return self.classifier(text)

    def predict_batch(self, texts):
        """Make predictions on multiple texts (one pipeline result per input)."""
        self._require_classifier()
        return [self.classifier(text) for text in texts]

    def get_model_info(self):
        """Return a summary dict for the loaded model, or a message if not loaded."""
        if self.model is None:
            return "Model not loaded"
        return {
            "model_type": type(self.model).__name__,
            "num_parameters": sum(p.numel() for p in self.model.parameters()),
            "model_path": self.model_path,
        }
# Example usage and convenience functions
def load_acoli_model(model_path="./acoli-model"):
    """Build and return an AcoliModel with weights loaded from *model_path*."""
    loaded = AcoliModel(model_path)
    return loaded
def create_training_instance():
    """Instantiate and return an AcoliTrainer from the project's Train module."""
    # Imported lazily so merely importing this module does not require Train.
    from Train import AcoliTrainer

    trainer = AcoliTrainer()
    return trainer
if __name__ == "__main__":
    # Demo: construct an empty wrapper. Loading and prediction are left
    # commented out because they require trained weights on disk:
    #
    #   acoli.load_model("./acoli-model")
    #   prediction = acoli.predict("Your Acoli text here")
    #   print(prediction)
    acoli = AcoliModel()
    print("Acoli model class ready for use!")