ReCasePunct 1 Flash

We introduce ReCasePunct 1 Flash, our first model capable of punctuation and casing restoration!

Given lowercase, unpunctuated English text of any length (within the practical limits I tested), this model can predict punctuation and casing, and it's impressive!

It also runs very fast on CPU!

A typical use case is ASR output (some models produce text without casing and punctuation, like the auto-generated subtitles on YouTube videos from 2023/2024/2025)

Limitations

This model was trained ONLY on English Tatoeba data and doesn't do well for other languages.

Also, it doesn't always predict perfectly (especially with proper nouns like "Minecraft").

We might train a multi-lingual and better ReCasePunct model next!

How To Run It

Code by Gemini 2.5 Flash:

from transformers import AutoTokenizer, AlbertConfig, AlbertModel
import torch
import torch.nn as nn
import re
import numpy as np
from safetensors.torch import load_file # Import safe_load for safetensors
from huggingface_hub import hf_hub_download # Import hf_hub_download

# Redefine the model class (must be the same as during training)
class AlbertForPunctuationAndCasing(nn.Module):
    """ALBERT encoder with two token-level classification heads.

    One head predicts the punctuation mark that follows each token, the
    other predicts the token's casing class. The class must mirror the
    architecture used at training time so the saved state dict loads.
    """

    def __init__(self, config):
        super().__init__()
        # Label-space sizes are carried on the (customized) AlbertConfig.
        self.num_punctuation_labels = config.num_punctuation_labels
        self.num_casing_labels = config.num_casing_labels

        # Backbone built straight from the provided config; the config
        # should reflect the true albert-large-v2 architecture.
        self.albert = AlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)

        # One linear head per task, applied at every token position.
        self.punctuation_classifier = nn.Linear(config.hidden_size, self.num_punctuation_labels)
        self.casing_classifier = nn.Linear(config.hidden_size, self.num_casing_labels)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        casing_labels=None,
        punctuation_labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the encoder and both heads; compute a joint loss when labels are given."""
        if return_dict is None:
            return_dict = True

        encoder_out = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Per-token hidden states -> dropout -> two independent heads.
        hidden = self.dropout(encoder_out[0])
        punctuation_logits = self.punctuation_classifier(hidden)
        casing_logits = self.casing_classifier(hidden)

        # Joint loss: sum of the two per-token cross-entropies, computed
        # only when both label tensors are supplied (-100 = ignored position).
        loss = None
        if punctuation_labels is not None and casing_labels is not None:
            ce = nn.CrossEntropyLoss(ignore_index=-100)
            loss = (
                ce(punctuation_logits.view(-1, self.num_punctuation_labels), punctuation_labels.view(-1))
                + ce(casing_logits.view(-1, self.num_casing_labels), casing_labels.view(-1))
            )

        if not return_dict:
            # Tuple output mirrors the HF convention: (loss?, logits..., extras...).
            tail = (punctuation_logits, casing_logits) + encoder_out[2:]
            return tail if loss is None else (loss,) + tail

        result = {
            "loss": loss,
            "punctuation_logits": punctuation_logits,
            "casing_logits": casing_logits,
        }
        # Only surface optional encoder extras when they were requested.
        if encoder_out.hidden_states is not None:
            result["hidden_states"] = encoder_out.hidden_states
        if encoder_out.attentions is not None:
            result["attentions"] = encoder_out.attentions
        return result


# --- Configuration and Mappings (must be the same as during training) ---
# 'O' means "no punctuation" / "no casing change" respectively.
punctuation_labels = ['O', '.', ',', '?', '!', ';', ':', '-', '"', '(', ')', '/', '\\']
punctuation_to_id = {label: idx for idx, label in enumerate(punctuation_labels)}
id_to_punctuation = dict(enumerate(punctuation_labels))

casing_labels = ['O', 'CAP', 'UPPER']
casing_to_id = {label: idx for idx, label in enumerate(casing_labels)}
id_to_casing = dict(enumerate(casing_labels))

# Base checkpoint whose architecture the custom heads were trained on top of.
model_checkpoint = 'albert-large-v2'

# Hugging Face repository that hosts the fine-tuned tokenizer and weights.
hf_repo_id = "MihaiPopa-1/ReCasePunct-1-Flash"

# Load tokenizer from Hugging Face Hub (downloads on first use, then cached).
tokenizer = AutoTokenizer.from_pretrained(hf_repo_id)

# --- CORRECTED MODEL CONFIG LOADING ---
# 1. Load the base ALBERT Large v2 configuration to get correct architecture
#    defaults (like hidden_size), rather than the repo's config.
config = AlbertConfig.from_pretrained(model_checkpoint)

# 2. Attach the custom label-space sizes the heads need on this config.
config.num_punctuation_labels = len(punctuation_labels)
config.num_casing_labels = len(casing_labels)

# Instantiate the custom model with the corrected config (weights still random here).
model = AlbertForPunctuationAndCasing(config)

# Download the model.safetensors file from the Hub.
safetensors_path = hf_hub_download(repo_id=hf_repo_id, filename="model.safetensors")

# Load the full state dictionary into the custom model.
# NOTE(review): this is a strict load — the checkpoint's keys must match the
# class definition above exactly, which is why the class must be redefined as-is.
model.load_state_dict(load_file(safetensors_path, device='cpu'))
model.eval()  # disable dropout for deterministic inference


def clean_text(text):
    """Lowercase *text* and strip the punctuation marks the model predicts.

    Mirrors the preprocessing applied to the training data. Apostrophes
    are deliberately left in place (e.g. "it's"), and any resulting runs
    of whitespace are collapsed to single spaces.
    """
    stripped = re.sub(r'[\.,\?!\-;:"\(\)\[\]\{\}\/\\]', '', text.lower())
    return re.sub(r'\s+', ' ', stripped).strip()

def predict_punctuation_and_casing(text, model, tokenizer, id_to_punctuation, id_to_casing):
    """Restore punctuation and casing for *text*.

    The input is normalized exactly like the training data (lowercased,
    punctuation stripped), run through the model, then reconstructed word
    by word: the casing label of a word's first subword token sets its
    case, and the punctuation label of its (heuristically detected) last
    subword token appends a mark after it.

    Args:
        text: Raw input sentence (any casing/punctuation; cleaned first).
        model: AlbertForPunctuationAndCasing instance (in eval mode).
        tokenizer: Matching fast tokenizer (offset-mapping support required).
        id_to_punctuation: Maps punctuation class id -> label ('O' = none).
        id_to_casing: Maps casing class id -> label ('O'/'CAP'/'UPPER').

    Returns:
        The reconstructed sentence as a single string.
    """
    # Clean the input text the same way the training data was prepared.
    cleaned_text_input = clean_text(text)
    words_in_cleaned_text = cleaned_text_input.split()

    tokenized_input = tokenizer(
        cleaned_text_input,
        return_offsets_mapping=True,
        truncation=True,
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    )

    with torch.no_grad():
        outputs = model(
            input_ids=tokenized_input['input_ids'],
            attention_mask=tokenized_input['attention_mask'],
        )

    punctuation_predictions = np.argmax(outputs['punctuation_logits'].squeeze(0).numpy(), axis=-1)
    casing_predictions = np.argmax(outputs['casing_logits'].squeeze(0).numpy(), axis=-1)

    # FIX: hoist tensor->numpy conversions out of the loop. The original
    # re-ran offset_mapping.squeeze(0).numpy() on every iteration (twice
    # per token inside the last-subword check).
    offsets = tokenized_input['offset_mapping'].squeeze(0).numpy()
    input_ids = tokenized_input['input_ids'].squeeze(0)
    special_ids = (tokenizer.cls_token_id, tokenizer.sep_token_id)

    reconstructed_text_parts = []
    current_word_idx = 0

    for token_idx, (token_start, token_end) in enumerate(offsets):
        # (0, 0) offsets mark special tokens like [CLS]/[SEP]; skip them.
        if token_start == 0 and token_end == 0:
            continue

        # Surface text of this token within the cleaned input string.
        token_text = cleaned_text_input[token_start:token_end]

        # Heuristic word alignment: treat this token as starting the next
        # word when its surface text is a prefix of that word. Labels are
        # applied once per word; subsequent subword tokens fall through.
        if current_word_idx < len(words_in_cleaned_text) and words_in_cleaned_text[current_word_idx].startswith(token_text):
            word = words_in_cleaned_text[current_word_idx]

            # Casing comes from the word's first subword token.
            casing_pred_label = id_to_casing[casing_predictions[token_idx]]
            if casing_pred_label == 'CAP':
                word = word.capitalize()
            elif casing_pred_label == 'UPPER':
                word = word.upper()

            # Punctuation is taken from what is heuristically the last
            # subword token of the word: either this is the final token,
            # or the next token starts at/after this word's end, or the
            # next token is a special token. This heuristic may need
            # refinement for complex tokenizations.
            is_last_subword = True
            if token_idx + 1 < len(offsets):
                next_token_start = offsets[token_idx + 1][0]
                is_last_subword = (
                    next_token_start >= token_end
                    or input_ids[token_idx + 1].item() in special_ids
                )
            if is_last_subword:
                punctuation_pred_label = id_to_punctuation[punctuation_predictions[token_idx]]
                if punctuation_pred_label != 'O':
                    word += punctuation_pred_label

            reconstructed_text_parts.append(word)
            current_word_idx += 1

    # Join the words, then tighten spacing around the predicted marks.
    result = ' '.join(reconstructed_text_parts)
    for spaced, tight in (
        (' .', '.'), (' ,', ','), (' ?', '?'), (' !', '!'), (' ;', ';'),
        (' :', ':'), (' -', '-'), (' "', '"'), ('( ', '('), (' )', ')'),
        (' /', '/'), (' \\', '\\'),
    ):
        result = result.replace(spaced, tight)
    return result

# --- Test Case for a single sentence ---
demo_sentence = "replace me by whatever sentence you like"

print(f"Original: {demo_sentence}")
print(f"Predicted: {predict_punctuation_and_casing(demo_sentence, model, tokenizer, id_to_punctuation, id_to_casing)}\n")

Should give: Replace me by whatever sentence you like.

Examples

Original Sentence Predicted Sentence
this is a test of punctuation prediction for english how are you doing today This is a test of punctuation prediction for English. How are you doing today?
i love running this on t4 gpu and so for this goal we might make a better and more accurate model in the future I love running this on T4 GPU and so, for this goal, we might make a better and more accurate model in the future.
so imagine this we live in a world with complex models yet this model does punctuation and casing prediction for english and it's very small at just only 18 million parameters So, imagine this, we live in a world with complex models. Yet this model does punctuation and casing prediction for English, and it's very small at just only 18 million parameters.

Evaluation Results

Epoch Training Loss Validation Loss Punctuation Accuracy Casing Accuracy Overall Accuracy
1 0.072175 0.070485 0.642053 (64.21%) 0.638791 (63.88%) 0.640422 (64.04%)
2 0.052846 0.063811 0.642343 (64.23%) 0.640475 (64.05%) 0.641409 (64.14%)
3 0.031407 0.062892 0.640457 (64.05%) 0.640098 (64.01%) 0.640278 (64.03%)
Downloads last month

-

Downloads are not tracked for this model. How to track
Safetensors
Model size
17.7M params
Tensor type
F32
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for MihaiPopa-1/ReCasePunct-1-Flash

Finetuned
(26)
this model