|
|
import torch |
|
|
from transformers import ( |
|
|
AutoTokenizer, |
|
|
AutoModelForSequenceClassification, |
|
|
pipeline |
|
|
) |
|
|
import logging |
|
|
|
|
|
# Module-scoped logger named after this module, so its level and handlers
# can be configured through the standard logging hierarchy.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
class AcoliModel:
    """Wrapper around a Hugging Face sequence-classification model.

    Bundles a tokenizer, the model, and a ``"text-classification"`` pipeline
    built from the two. Pass ``model_path`` to load on construction, or
    construct empty and call :meth:`load_model` later. Until a model is
    loaded, :meth:`predict` and :meth:`predict_batch` raise ``ValueError``.
    """

    def __init__(self, model_path=None):
        """Initialise empty state and optionally load a model.

        Args:
            model_path: Directory or identifier accepted by
                ``from_pretrained``. If truthy, :meth:`load_model` is called
                immediately and any loading error propagates; otherwise
                nothing is loaded.
        """
        self.model_path = model_path
        self.tokenizer = None
        self.model = None
        self.classifier = None

        if model_path:
            self.load_model(model_path)

    def load_model(self, model_path):
        """Load a trained model, its tokenizer and a classification pipeline.

        All three components are built into locals first and only assigned to
        the instance after every step has succeeded. A failed load therefore
        leaves the previous state (loaded or not) intact instead of a
        half-initialised mix. On success ``self.model_path`` is updated so
        :meth:`get_model_info` reports the path that was actually loaded.

        Args:
            model_path: Directory or identifier passed to ``from_pretrained``.

        Raises:
            Exception: Whatever the underlying loaders raise; it is logged
                and re-raised unchanged.
        """
        try:
            tokenizer = AutoTokenizer.from_pretrained(model_path)
            model = AutoModelForSequenceClassification.from_pretrained(model_path)
            classifier = pipeline(
                "text-classification",
                model=model,
                tokenizer=tokenizer,
            )
        except Exception as e:
            logger.error("Error loading model: %s", e)
            raise

        # Commit only once everything above succeeded, so a failure cannot
        # pair a new tokenizer with an old (or missing) model.
        self.tokenizer = tokenizer
        self.model = model
        self.classifier = classifier
        # Keep the recorded path in sync with what is loaded. Without this,
        # a direct load_model() call after construction left it stale.
        self.model_path = model_path
        logger.info("Model loaded successfully from %s", model_path)

    def _require_classifier(self):
        """Return the loaded pipeline, or raise ``ValueError`` if there is none."""
        if self.classifier is None:
            raise ValueError("Model not loaded. Call load_model() first.")
        return self.classifier

    def predict(self, text):
        """Classify ``text`` with the loaded pipeline.

        Args:
            text: Input passed straight to the pipeline.

        Returns:
            Whatever the pipeline returns for ``text``. For a single string
            this is typically a list of ``{"label": ..., "score": ...}`` dicts.

        Raises:
            ValueError: If no model has been loaded.
        """
        return self._require_classifier()(text)

    def predict_batch(self, texts):
        """Classify each element of ``texts`` with a separate pipeline call.

        Calling the pipeline once per element means each entry of the result
        has exactly the shape :meth:`predict` returns for that element.
        NOTE(review): the pipeline also accepts a list in a single call, but
        its output nesting would then differ. Confirm before switching.

        Args:
            texts: Iterable of inputs.

        Returns:
            list: One pipeline result per input, in input order.

        Raises:
            ValueError: If no model has been loaded (checked before iterating).
        """
        classifier = self._require_classifier()
        return [classifier(text) for text in texts]

    def get_model_info(self):
        """Summarise the loaded model.

        Returns:
            The string ``"Model not loaded"`` if no model is loaded. Otherwise
            a dict with ``model_type`` (class name of the model),
            ``num_parameters`` (sum of ``numel()`` over
            ``model.parameters()``) and ``model_path``.
        """
        # The str-vs-dict return is kept as-is because callers may rely on it.
        if self.model is None:
            return "Model not loaded"

        return {
            "model_type": type(self.model).__name__,
            "num_parameters": sum(p.numel() for p in self.model.parameters()),
            "model_path": self.model_path,
        }
|
|
|
|
|
|
|
|
def load_acoli_model(model_path="./acoli-model"):
    """Construct an ``AcoliModel`` for ``model_path`` and return it.

    With the default (truthy) path the model is loaded immediately, and any
    loading error propagates. A falsy ``model_path`` yields an unloaded
    instance.
    """
    acoli = AcoliModel(model_path=model_path)
    return acoli
|
|
|
|
|
def create_training_instance():
    """Return a freshly constructed ``AcoliTrainer`` (no arguments).

    The ``Train`` module is imported at call time rather than at module load.
    NOTE(review): presumably this avoids a circular import or heavy training
    dependencies on import; confirm.
    """
    from Train import AcoliTrainer as trainer_cls

    trainer = trainer_cls()
    return trainer
|
|
|
|
|
if __name__ == "__main__":
    # Construct without a checkpoint path: __init__ skips load_model(), so
    # this only confirms the class can be instantiated.
    model = AcoliModel(model_path=None)
    ready_message = "Acoli model class ready for use!"
    print(ready_message)