# ga-POS-tagger / train_perceptron.py
# (uploaded by laurencassidy, commit 1170fda)
"""
Averaged Perceptron POS Tagger for Irish Twitter data.
Trains on TwittIrish CoNLL-U data and saves the model.
"""
import pickle
import json
from collections import defaultdict
from pathlib import Path
def parse_conllu(filepath):
    """Read a CoNLL-U file and return sentences as lists of (token, POS) pairs.

    Comment lines are ignored; a blank line terminates the current sentence.
    Multiword-token ranges (IDs like "1-2") and empty nodes (IDs like "1.1")
    are skipped so only regular token rows are kept.
    """
    sentences = []
    current = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for raw in f:
            stripped = raw.strip()
            if stripped.startswith('#'):
                continue
            if not stripped:
                # Sentence boundary: flush the tokens gathered so far.
                if current:
                    sentences.append(current)
                    current = []
                continue
            cols = stripped.split('\t')
            # Keep only plain token rows (integer ID, at least 4 columns).
            if len(cols) < 4 or '-' in cols[0] or '.' in cols[0]:
                continue
            current.append((cols[1], cols[3]))
    # File may end without a trailing blank line.
    if current:
        sentences.append(current)
    return sentences
def extract_features(sentence, index, prev_tag=None, prev_prev_tag=None):
    """Return the feature dict for the token at *index* in *sentence*.

    *sentence* is a list of (token, tag) pairs; only the token part is read.
    *prev_tag* / *prev_prev_tag* carry the tag history for sequence features.
    Slicing a short string in Python simply yields the whole string, so the
    prefix/suffix features need no length guards.
    """
    word = sentence[index][0]
    features = {
        'bias': 1.0,
        'word': word.lower(),
        # Suffixes and prefixes (orthographic cues, useful for Irish morphology).
        'word[-3:]': word[-3:].lower(),
        'word[-2:]': word[-2:].lower(),
        'word[-1:]': word[-1:].lower(),
        'word[:3]': word[:3].lower(),
        'word[:2]': word[:2].lower(),
        'word[:1]': word[:1].lower(),
        # Shape features.
        'is_upper': word.isupper(),
        'is_title': word.istitle(),
        'is_digit': word.isdigit(),
        'is_alpha': word.isalpha(),
        'has_hyphen': '-' in word,
        # Twitter-specific features.
        'is_mention': word.startswith('@'),
        'is_hashtag': word.startswith('#'),
        'is_url': word.startswith('http'),
        'is_rt': word.upper() == 'RT',
        'length': len(word),
    }
    # Tag-history features (critical for sequence labeling).
    if prev_tag:
        features['prev_tag'] = prev_tag
        features['prev_tag+word'] = f'{prev_tag}+{word.lower()}'
    else:
        features['BOS_tag'] = True
    if prev_prev_tag:
        features['prev_prev_tag'] = prev_prev_tag
        features['prev_prev_tag+prev_tag'] = f'{prev_prev_tag}+{prev_tag}'
    # Context window: previous token.
    if index > 0:
        before = sentence[index - 1][0]
        features['prev_word'] = before.lower()
        features['prev_word[-3:]'] = before[-3:].lower()
        if prev_tag:
            features['prev_word+prev_tag'] = f'{before.lower()}+{prev_tag}'
    else:
        features['BOS'] = True  # Beginning of sentence
    # Context window: next token.
    if index < len(sentence) - 1:
        after = sentence[index + 1][0]
        features['next_word'] = after.lower()
        features['next_word[:3]'] = after[:3].lower()
    else:
        features['EOS'] = True  # End of sentence
    return features
class AveragedPerceptron:
    """Averaged Perceptron classifier for sequence labeling.

    Weights are stored as feature -> label -> weight.  During training we
    also keep a running total and a last-update timestamp per
    (feature, label) pair so the final weights can be averaged over all
    update steps, which smooths out the bias toward late examples.
    """

    def __init__(self):
        # weights[feat][label] -> current weight value
        self.weights = defaultdict(lambda: defaultdict(float))
        # _totals[feat][label] -> sum of weight * steps-held (for averaging)
        self._totals = defaultdict(lambda: defaultdict(float))
        # _timestamps[feat][label] -> step at which the weight last changed
        self._timestamps = defaultdict(lambda: defaultdict(int))
        # Number of update() calls seen so far (one per training token).
        self._i = 0

    def predict(self, features):
        """Return the highest-scoring label for the given feature dict.

        Numeric feature values (e.g. 'bias', 'length') contribute their
        value; all other values count as 1.0.  Note bool is a subclass of
        int, so False-valued flags contribute 0 and True-valued flags 1.
        Falls back to 'NOUN' when no known feature fires.
        """
        scores = defaultdict(float)
        for feat, value in features.items():
            if feat not in self.weights:
                continue
            # Hoisted out of the label loop: the coercion is per-feature.
            numeric = value if isinstance(value, (int, float)) else 1.0
            for label, weight in self.weights[feat].items():
                scores[label] += weight * numeric
        if not scores:
            return 'NOUN'  # Default fallback
        return max(scores, key=scores.get)

    def update(self, truth, guess, features):
        """Apply one perceptron update; weights only move on a wrong guess.

        The step counter advances even when the guess is correct, so that
        averaging credits stable weights for every step they survive.
        """
        self._i += 1
        if truth == guess:
            return
        for feat in features:
            self._update_feat(truth, feat, 1.0)
            self._update_feat(guess, feat, -1.0)

    def _update_feat(self, label, feat, value):
        """Shift one (feature, label) weight by *value*, tracking totals."""
        curr_weight = self.weights[feat][label]
        # Credit the old weight for every step it was in force before now.
        self._totals[feat][label] += (self._i - self._timestamps[feat][label]) * curr_weight
        self._timestamps[feat][label] = self._i
        self.weights[feat][label] = curr_weight + value

    def average_weights(self):
        """Replace each weight with its average over all update steps.

        Also converts the nested defaultdicts to plain dicts, which makes
        the weights safe to pickle (no lambda default factories).
        """
        for feat, label_weights in self.weights.items():
            for label in label_weights:
                total = self._totals[feat][label]
                # Credit the current weight for the steps since its last update.
                total += (self._i - self._timestamps[feat][label]) * self.weights[feat][label]
                self.weights[feat][label] = total / self._i if self._i else 0
        # Convert to regular dicts for serialization
        self.weights = {feat: dict(lw) for feat, lw in self.weights.items()}
class PerceptronTagger:
    """POS Tagger using Averaged Perceptron with greedy left-to-right decoding."""

    def __init__(self):
        self.model = AveragedPerceptron()
        # Set of all tags observed during training.
        self.tagset = set()

    def train(self, sentences, n_epochs=5):
        """Train the tagger on annotated sentences.

        *sentences* is a list of [(token, tag), ...] lists.  Runs *n_epochs*
        passes with per-epoch shuffling, then averages the weights.
        """
        import random  # Fix: was re-imported inside the epoch loop.

        print(f"Training on {len(sentences)} sentences for {n_epochs} epochs...")
        for epoch in range(n_epochs):
            correct = 0
            total = 0
            # Shuffle sentences each epoch so online updates don't overfit
            # to a fixed sentence order.
            shuffled = sentences.copy()
            random.shuffle(shuffled)
            for sentence in shuffled:
                prev_tag = None
                prev_prev_tag = None
                for i, (token, true_tag) in enumerate(sentence):
                    features = extract_features(sentence, i, prev_tag, prev_prev_tag)
                    guess = self.model.predict(features)
                    self.model.update(true_tag, guess, features)
                    self.tagset.add(true_tag)
                    correct += (guess == true_tag)
                    total += 1
                    # Update tag history with TRUE tag (teacher forcing)
                    prev_prev_tag = prev_tag
                    prev_tag = true_tag
            acc = correct / total if total else 0
            print(f" Epoch {epoch + 1}/{n_epochs}: accuracy = {acc:.4f}")
        self.model.average_weights()
        print("Training complete. Weights averaged.")

    def tag(self, tokens):
        """Tag a list of tokens; returns a list of (token, tag) pairs."""
        sentence = [(t, None) for t in tokens]
        tags = []
        prev_tag = None
        prev_prev_tag = None
        for i in range(len(tokens)):
            features = extract_features(sentence, i, prev_tag, prev_prev_tag)
            tag = self.model.predict(features)
            tags.append(tag)
            sentence[i] = (tokens[i], tag)
            # Update tag history with PREDICTED tag (greedy decoding).
            prev_prev_tag = prev_tag
            prev_tag = tag
        return list(zip(tokens, tags))

    def evaluate(self, sentences):
        """Return token-level accuracy on gold-annotated sentences."""
        correct = 0
        total = 0
        for sentence in sentences:
            tokens = [t for t, _ in sentence]
            true_tags = [tag for _, tag in sentence]
            predicted = self.tag(tokens)
            pred_tags = [tag for _, tag in predicted]
            for true, pred in zip(true_tags, pred_tags):
                correct += (true == pred)
                total += 1
        return correct / total if total else 0

    def save(self, filepath):
        """Pickle the (averaged) weights and tag set to *filepath*."""
        data = {
            'weights': self.model.weights,
            'tagset': list(self.tagset)
        }
        with open(filepath, 'wb') as f:
            pickle.dump(data, f)
        print(f"Model saved to {filepath}")

    @classmethod
    def load(cls, filepath):
        """Load a trained model.

        SECURITY NOTE: pickle.load executes arbitrary code from the file;
        only load model files from trusted sources.
        """
        tagger = cls()
        with open(filepath, 'rb') as f:
            data = pickle.load(f)
        # Rebuild nested defaultdicts so predict() can probe missing keys.
        tagger.model.weights = defaultdict(lambda: defaultdict(float),
            {k: defaultdict(float, v) for k, v in data['weights'].items()})
        tagger.tagset = set(data['tagset'])
        return tagger
def main():
    """Train, evaluate, save, and demo the Irish Twitter POS tagger."""
    # Single source of truth for the epoch count (was duplicated as a
    # magic number in the train() call and the results dict).
    n_epochs = 20
    data_dir = Path(__file__).parent.parent / 'data' / 'UD_Irish-TwittIrish-master'

    # Load data
    print("Loading TwittIrish data...")
    train_sents = parse_conllu(data_dir / 'ga_twittirish-ud-train.conllu')
    dev_sents = parse_conllu(data_dir / 'ga_twittirish-ud-dev.conllu')
    test_sents = parse_conllu(data_dir / 'ga_twittirish-ud-test.conllu')
    print(f" Train: {len(train_sents)} sentences")
    print(f" Dev: {len(dev_sents)} sentences")
    print(f" Test: {len(test_sents)} sentences")

    # Train on train+dev for a stronger final model (more epochs for
    # better convergence).
    # NOTE(review): because dev is folded into training, the dev accuracy
    # below is inflated; only the test figure is a held-out score.
    tagger = PerceptronTagger()
    tagger.train(train_sents + dev_sents, n_epochs=n_epochs)

    # Evaluate
    dev_acc = tagger.evaluate(dev_sents)
    test_acc = tagger.evaluate(test_sents)
    print(f"\nResults:")
    print(f" Dev accuracy: {dev_acc:.4f}")
    print(f" Test accuracy: {test_acc:.4f}")

    # Save model
    model_path = Path(__file__).parent / 'perceptron_model.pkl'
    tagger.save(model_path)

    # Save results
    results = {
        'model': 'Averaged Perceptron',
        'dev_accuracy': dev_acc,
        'test_accuracy': test_acc,
        'n_epochs': n_epochs,
        'train_sentences': len(train_sents),
    }
    results_path = Path(__file__).parent / 'results.json'
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=2)
    print(f"Results saved to {results_path}")

    # Demo
    print("\nDemo:")
    demo_text = "Tá an aimsir go maith inniu"
    tokens = demo_text.split()
    tagged = tagger.tag(tokens)
    print(f" Input: {demo_text}")
    print(f" Tagged: {tagged}")
# Script entry point: train and evaluate when run directly, not on import.
if __name__ == '__main__':
    main()