---
library_name: transformers
tags:
- embedding
- semantic-similarity
- multilingual
- gemma2
- mitra
---

# Model Card for gemma2-mitra-embedding

Multilingual sentence embedding model based on Gemma 2, designed for semantic similarity and retrieval. It is used in the Mitra alignment stack to embed sentences in languages such as Sanskrit, Tibetan, Pali, Chinese, and English for cross-lingual sentence alignment and similarity search.

### Direct Use

- **Semantic similarity:** Encode sentences and compare them via cosine similarity (embeddings are L2-normalized).
- **Retrieval:** Encode queries with the **query** template and corpus passages with the **corpus** (no instruction) format; retrieve by nearest-neighbor search (e.g. FAISS).
- **Multilingual alignment:** Used in this repo to embed source and target sentences for sentence-level alignment (e.g. Buddhist texts across Sanskrit, Tibetan, Pali, Chinese, English).

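Because the embeddings are L2-normalized, cosine similarity reduces to a plain dot product, so nearest-neighbor retrieval is a single matrix multiply. A minimal NumPy sketch with toy 2-d vectors standing in for real model embeddings (the real vectors have dimension 3584):

```python
import numpy as np

def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Normalize rows to unit length so dot product equals cosine similarity."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

# Toy stand-ins for query/corpus embeddings produced by the model.
corpus = l2_normalize(np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]))
query = l2_normalize(np.array([[0.9, 0.1]]))

scores = query @ corpus.T      # cosine similarities, shape (1, 3)
best = int(np.argmax(scores))  # index of the most similar corpus sentence
```

For large corpora, the same dot-product search is what a FAISS inner-product index (`IndexFlatIP`) computes over the normalized vectors.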
### Downstream Use

- Integration into alignment pipelines (e.g. Bertalign-style alignment with this embedder).
- RAG or search systems that need multilingual, instruction-aware query/corpus embeddings.
- Any application that consumes L2-normalized sentence vectors from this model.

### Recommendations

- Use the exact prompt format (see "How to Get Started") for queries and corpus passages.
- Be aware of potential biases and limitations; evaluate on your own data and languages before deployment.

## How to Get Started with the Model

### Prompt / template pattern

The model expects **asymmetric** inputs:

- **Query (instruction + query):**
  `<instruct>Please find the semantically most similar text in {language}.\n<query>{sentence_text}`
  where `{language}` is a full language name (e.g. "Sanskrit", "Tibetan", "English", "Chinese", "Pali", "Hindi") and `{sentence_text}` is the query string.

- **Corpus:**
  Use the raw sentence (or passage) text **only**, with no `<instruct>` or `<query>` wrapper.

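The asymmetric formatting above can be captured in two small helpers. A sketch for illustration (the function names `build_query` and `build_corpus` are not part of the repo, only the template string is):

```python
QUERY_TEMPLATE = (
    "<instruct>Please find the semantically most similar text in {language}.\n"
    "<query>{text}"
)

def build_query(text: str, language: str) -> str:
    """Wrap a query sentence in the instruction + <query> template."""
    return QUERY_TEMPLATE.format(language=language, text=text)

def build_corpus(text: str) -> str:
    """Corpus passages are encoded as raw text, with no wrapper."""
    return text

print(build_query("tathāgata", "Sanskrit"))
# <instruct>Please find the semantically most similar text in Sanskrit.
# <query>tathāgata
```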
### Example (8-bit quantization with Hugging Face Transformers)

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

model_path = "gemma2-mitra-embedding"
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_threshold=6.0,
    llm_int8_has_fp16_weight=False,
)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    quantization_config=quantization_config,
    device_map={"": 0},
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Query encoding (one sentence)
language = "Sanskrit"  # or Tibetan, English, Chinese, Pali, Hindi
text = "Your query sentence here."
prompt = f"<instruct>Please find the semantically most similar text in {language}.\n<query>{text}"
inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).to(model.device)
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
# Last-token embedding (last non-padded position)
last_token_idx = inputs["attention_mask"].sum(dim=1) - 1
embedding = outputs.hidden_states[-1][torch.arange(last_token_idx.size(0)), last_token_idx]
# L2-normalize for cosine similarity
embedding = embedding / embedding.norm(dim=-1, keepdim=True)
```

For **corpus** sentences, pass only the raw text (no `<instruct>`/`<query>`), then take the last token hidden state and L2-normalize the same way.

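The last-token pooling used for both queries and corpus can be illustrated without loading the model: given an attention mask, pick the hidden state at the last non-padded position of each sequence. A toy NumPy sketch of the same indexing (hidden size 3 here; the real model uses 3584):

```python
import numpy as np

# Toy batch: 2 sequences, 4 positions, hidden size 3.
hidden = np.arange(2 * 4 * 3, dtype=float).reshape(2, 4, 3)
attention_mask = np.array([[1, 1, 1, 0],   # sequence 1: 3 real tokens
                           [1, 1, 0, 0]])  # sequence 2: 2 real tokens

# Index of the last non-padded token per sequence.
last_idx = attention_mask.sum(axis=1) - 1              # [2, 1]
pooled = hidden[np.arange(hidden.shape[0]), last_idx]  # shape (2, 3)

# L2-normalize, as in the Transformers example.
pooled = pooled / np.linalg.norm(pooled, axis=-1, keepdims=True)
```

With right-padding (the default here, since `pad_token` is set), this picks the true final token of each sentence rather than a pad position.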
Alternatively, use **FlagEmbedding**'s `FlagLLMModel` with this model path for `encode_queries` and `encode_corpus` (see [FlagEmbedding](https://github.com/FlagOpen/FlagEmbedding)).

### Model Architecture and Objective

- **Architecture:** Gemma 2 (`Gemma2Model`), used as an encoder: input text → last-token hidden state → L2-normalized embedding.
- **Config (from repo):** `hidden_size=3584`, `num_hidden_layers=42`, `num_attention_heads=16`, `num_key_value_heads=8`, `intermediate_size=14336`, `head_dim=256`, `max_position_embeddings=8192`, `sliding_window=4096`, `vocab_size=256002` (includes special tokens `<instruct>`, `<query>`).
- **Special tokens:** `<instruct>`, `<query>` (see `special_tokens_map.json` / `added_tokens.json` in the model dir).
- **Objective:** Dense retrieval / semantic similarity (asymmetric query/corpus encoding).