Model Card for all-MiniLM-L6-v2 LiteRT

The all-MiniLM-L6-v2 sentence embeddings model, exported to LiteRT (formerly TensorFlow Lite) format.

This model can be run with txtai using the following code.

import txtai

# path selects the fp16 .tflite export from this model repository
embeddings = txtai.Embeddings(
  path="neuml/all-MiniLM-L6-v2-litert/all-MiniLM-L6-v2-fp16.tflite",
  content=True,
)
# documents() is a placeholder for your own iterable of documents to index;
# it is not defined in this snippet
embeddings.index(documents())

# Run a query
embeddings.search("query to run")

The model was exported with the following script.

#
# pip install litert-torch txtai
#
# See https://github.com/google-ai-edge/litert-torch
#

import argparse
import json
import os

import litert_torch
import torch

from ai_edge_quantizer import quantizer, recipe

from litert_torch.generative.quantize import quant_recipes
from torch import nn
from transformers import AutoTokenizer
from txtai.models import PoolingFactory
from txtai.util import Download


class Pooling(nn.Module):
    """
    Thin nn.Module wrapper around a txtai pooling model, used as the module that
    litert-torch traces and converts.

    Besides building the model with txtai's PoolingFactory, it records the model's
    max_seq_length so the export can use a fixed sequence length.
    """

    def __init__(self, path, device, **kwargs):
        super().__init__()

        settings = {"path": path, "device": device, "modelargs": kwargs}
        self.model = PoolingFactory.create(settings)

        # The exported graph has a static shape, so take the sequence length from
        # sentence_bert_config.json instead of the tokenizer's max length
        location = f"{path}/sentence_bert_config.json"
        if not os.path.exists(location):
            # Not a local file: resolve it with txtai's Download helper
            # (presumably fetching it for hub model ids - TODO confirm)
            location = Download()(location)

        with open(location, encoding="utf-8") as handle:
            self.maxlength = json.load(handle)["max_seq_length"]

    # pylint: disable=W0221
    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None):
        """
        Runs the wrapped pooling model.

        input_ids and attention_mask are always passed through; token_type_ids is
        only passed when it is not None.
        """

        extra = {} if token_type_ids is None else {"token_type_ids": token_type_ids}
        return self.model.forward(input_ids=input_ids, attention_mask=attention_mask, **extra)


def export(args):
    """
    Exports a txtai pooling model to a LiteRT (.tflite) file.

    Args:
        args: parsed command line arguments with input (model path or id), output
              (output directory) and quant (None for fp32, "int4", "int8" or "fp16")

    Returns:
        path to the exported .tflite file
    """

    # device -1 is passed straight through to txtai's PoolingFactory
    # (presumably selects CPU - TODO confirm)
    model = Pooling(args.input, -1).float().eval()

    # Sample inputs used for tracing. The model is exported with a static shape,
    # so these fix the (batch, sequence length) of the converted graph.
    batch, maxlength = 4, model.maxlength
    inputs = (
        torch.ones(batch, maxlength, dtype=torch.int32),
        torch.ones(batch, maxlength, dtype=torch.int32),
        torch.ones(batch, maxlength, dtype=torch.int32),
    )

    # Normalize first: the basename of a path with a trailing separator
    # (e.g. "models/minilm/") is "", which would yield a file named "-fp32.tflite"
    base = os.path.basename(os.path.normpath(args.input))

    # No quantization and int4 both start from an unquantized fp32 export
    if args.quant == "int8":
        config, path = quant_recipes.full_dynamic_recipe(), f"{base}-int8.tflite"
    elif args.quant == "fp16":
        config, path = quant_recipes.full_fp16_recipe(), f"{base}-fp16.tflite"
    else:
        config, path = None, f"{base}-fp32.tflite"

    # Create output directory
    os.makedirs(args.output, exist_ok=True)
    path = os.path.join(args.output, path)

    # Save tokenizer alongside the model
    tokenizer = AutoTokenizer.from_pretrained(args.input)
    tokenizer.save_pretrained(args.output)

    # Convert and save model
    edge = litert_torch.convert(model, inputs, quant_config=config)
    edge.export(path)

    # Quantize to int4, if necessary. This post-processes the fp32 file written above
    # with ai_edge_quantizer (recipe name suggests int4 weights / fp32 activations);
    # the intermediate fp32 file is left in the output directory.
    if args.quant == "int4":
        qt = quantizer.Quantizer(path, recipe.dynamic_wi4_afp32())

        path = os.path.join(args.output, f"{base}-int4.tflite")
        qt.quantize().export_model(path, overwrite=True)

    return path


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Export a txtai pooling model to LiteRT (.tflite)")
    parser.add_argument("--input", help="model path", required=True)
    parser.add_argument("--output", help="model output directory", required=True)
    parser.add_argument(
        "--quant",
        help="model quantization (exports fp32 when omitted)",
        choices=["int4", "int8", "fp16"],
    )
    args = parser.parse_args()

    # Report where the exported model was written
    print(export(args))
Downloads last month: 27
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for NeuML/all-MiniLM-L6-v2-litert

Finetuned
(881)
this model