prelington commited on
Commit
c2ea27d
·
verified ·
1 Parent(s): 5089144

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +80 -0
model.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import (
3
+ AutoTokenizer,
4
+ AutoModelForSequenceClassification,
5
+ pipeline
6
+ )
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ class AcoliModel:
12
+ def __init__(self, model_path=None):
13
+ self.model_path = model_path
14
+ self.tokenizer = None
15
+ self.model = None
16
+ self.classifier = None
17
+
18
+ if model_path:
19
+ self.load_model(model_path)
20
+
21
+ def load_model(self, model_path):
22
+ """Load a trained model"""
23
+ try:
24
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
25
+ self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
26
+ self.classifier = pipeline(
27
+ "text-classification",
28
+ model=self.model,
29
+ tokenizer=self.tokenizer
30
+ )
31
+ logger.info(f"Model loaded successfully from {model_path}")
32
+ except Exception as e:
33
+ logger.error(f"Error loading model: {e}")
34
+ raise
35
+
36
+ def predict(self, text):
37
+ """Make prediction on input text"""
38
+ if self.classifier is None:
39
+ raise ValueError("Model not loaded. Call load_model() first.")
40
+
41
+ return self.classifier(text)
42
+
43
+ def predict_batch(self, texts):
44
+ """Make predictions on multiple texts"""
45
+ if self.classifier is None:
46
+ raise ValueError("Model not loaded. Call load_model() first.")
47
+
48
+ return [self.classifier(text) for text in texts]
49
+
50
+ def get_model_info(self):
51
+ """Get model information"""
52
+ if self.model is None:
53
+ return "Model not loaded"
54
+
55
+ return {
56
+ "model_type": type(self.model).__name__,
57
+ "num_parameters": sum(p.numel() for p in self.model.parameters()),
58
+ "model_path": self.model_path
59
+ }
60
+
61
+ # Example usage and convenience functions
62
+ def load_acoli_model(model_path="./acoli-model"):
63
+ """Convenience function to load the Acoli model"""
64
+ return AcoliModel(model_path)
65
+
66
+ def create_training_instance():
67
+ """Create a training instance"""
68
+ from Train import AcoliTrainer
69
+ return AcoliTrainer()
70
+
71
+ if __name__ == "__main__":
72
+ # Example usage
73
+ model = AcoliModel()
74
+
75
+ # After training, you can load the model like this:
76
+ # model.load_model("./acoli-model")
77
+ # prediction = model.predict("Your Acoli text here")
78
+ # print(prediction)
79
+
80
+ print("Acoli model class ready for use!")