This model was released on 2024-10-08 and added to Hugging Face Transformers on 2024-12-06.

PyTorch FlashAttention SDPA

Aria

Aria is a multimodal mixture-of-experts (MoE) model. Its goal is to open-source a training recipe for building a multimodal-native model from scratch. Aria has 3.9B and 3.5B activated parameters per visual and text token, respectively. Text is handled by an MoE decoder and visual inputs are handled by a lightweight visual encoder. It is trained in four stages: language pretraining, multimodal pretraining, multimodal long-context pretraining, and multimodal post-training.
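
To make "activated parameters" concrete, here is a minimal sketch of top-k expert routing in plain Python. It is only an illustration of the general MoE idea: the expert count, the value of `top_k`, and the parameter sizes are hypothetical and are not taken from Aria's configuration.

```python
import math
import random


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def route_token(router_logits, top_k):
    """Pick the top_k experts for one token and renormalize their weights."""
    probs = softmax(router_logits)
    chosen = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:top_k]
    norm = sum(probs[i] for i in chosen)
    return [(i, probs[i] / norm) for i in chosen]


def activated_params(shared_params, params_per_expert, top_k):
    """Parameters touched by one token: shared layers plus the selected experts."""
    return shared_params + top_k * params_per_expert


# Hypothetical toy sizes (NOT Aria's real configuration).
NUM_EXPERTS = 8
TOP_K = 2
SHARED = 100
PER_EXPERT = 50

random.seed(0)
logits = [random.gauss(0.0, 1.0) for _ in range(NUM_EXPERTS)]
routing = route_token(logits, TOP_K)

total = SHARED + NUM_EXPERTS * PER_EXPERT
active = activated_params(SHARED, PER_EXPERT, TOP_K)
print(routing)
print(f"{active} of {total} parameters are used for this token")
```

Because each token only runs through the experts its router selects, the per-token compute depends on the activated parameter count rather than the total model size.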

You can find all the original Aria checkpoints under the Aria organization.

Click on the Aria models in the right sidebar for more examples of how to apply Aria to different multimodal tasks.

The example below demonstrates how to generate text based on an image with [Pipeline] or the [AutoModel] class.

import torch
from transformers import pipeline

# Name the instance `pipe` so it does not shadow the imported `pipeline` function.
pipe = pipeline(
    "image-to-text",
    model="rhymes-ai/Aria",
    device=0,
    dtype=torch.bfloat16
)
pipe(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg",
    text="What is shown in this image?"
)

import torch
from transformers import AutoModelForCausalLM, AutoProcessor

model = AutoModelForCausalLM.from_pretrained(
    "rhymes-ai/Aria",
    device_map="auto",
    dtype=torch.bfloat16,
    attn_implementation="sdpa"
)

processor = AutoProcessor.from_pretrained("rhymes-ai/Aria")

messages = [
    {
        "role": "user", "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
            {"type": "text", "text": "What is shown in this image?"},
        ]
    },
]

inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt")
inputs = inputs.to(model.device, torch.bfloat16)

output = model.generate(
    **inputs,
    max_new_tokens=15,
    stop_strings=["<|im_end|>"],
    tokenizer=processor.tokenizer,
    do_sample=True,
    temperature=0.9,
)
output_ids = output[0][inputs["input_ids"].shape[1]:]
response = processor.decode(output_ids, skip_special_tokens=True)
print(response)
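
The slice `output[0][inputs["input_ids"].shape[1]:]` is there because, for decoder-only generation, the returned sequence starts with the prompt tokens and the new tokens follow. The sketch below shows the same slicing on plain lists; the token ids and the `strip_prompt` helper are made up for illustration.

```python
def strip_prompt(sequence_ids, prompt_length):
    """Drop the first prompt_length ids, keeping only the newly generated ones."""
    return sequence_ids[prompt_length:]


# Hypothetical token ids, for illustration only.
prompt_ids = [101, 7, 8, 9]
generated = prompt_ids + [42, 43, 2]  # the returned sequence echoes the prompt first

new_ids = strip_prompt(generated, len(prompt_ids))
print(new_ids)
```

Decoding only the sliced ids keeps the prompt text out of the printed response.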

Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the Quantization overview for more available quantization backends.

The example below uses torchao to quantize only the weights to int4. It loads the rhymes-ai/Aria-sequential_mlp checkpoint, which replaces grouped GEMM with torch.nn.Linear layers for easier quantization.

# pip install torchao
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoProcessor

quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
model = AutoModelForCausalLM.from_pretrained(
    "rhymes-ai/Aria-sequential_mlp",
    dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quantization_config
)
processor = AutoProcessor.from_pretrained(
    "rhymes-ai/Aria-sequential_mlp",
)

messages = [
    {
        "role": "user", "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pipeline-cat-chonk.jpeg"},
            {"type": "text", "text": "What is shown in this image?"},
        ]
    },
]

inputs = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt")
inputs = inputs.to(model.device, torch.bfloat16)

output = model.generate(
    **inputs,
    max_new_tokens=15,
    stop_strings=["<|im_end|>"],
    tokenizer=processor.tokenizer,
    do_sample=True,
    temperature=0.9,
)
output_ids = output[0][inputs["input_ids"].shape[1]:]
response = processor.decode(output_ids, skip_special_tokens=True)
print(response)
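
To give a feel for what int4 weight-only quantization with a `group_size` means, here is a simplified sketch in plain Python. It assumes a symmetric scheme with one floating-point scale per group; torchao's actual kernels, quantization scheme, and packed storage layout differ, so treat this as a conceptual illustration only. All function names are hypothetical.

```python
def quantize_group(weights):
    """Map one group of floats to integers in [-8, 7] with a shared scale."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 7 if max_abs > 0 else 1.0
    q = [max(-8, min(7, round(w / scale))) for w in weights]
    return q, scale


def quantize(weights, group_size):
    """Split the weights into consecutive groups and quantize each one."""
    return [
        quantize_group(weights[start:start + group_size])
        for start in range(0, len(weights), group_size)
    ]


def dequantize(groups):
    """Recover approximate float weights from (ints, scale) pairs."""
    out = []
    for q, scale in groups:
        out.extend(v * scale for v in q)
    return out


# Toy weights; a small group_size keeps the example readable.
weights = [0.12, -0.50, 0.33, 0.07, 2.0, -1.5, 0.9, 0.1]
groups = quantize(weights, group_size=4)
restored = dequantize(groups)
max_err = max(abs(a - b) for a, b in zip(weights, restored))
print(groups)
print(max_err)

# Rough storage cost under this simplified scheme: 4 bits per weight plus one
# 16-bit scale shared by each group of 128 weights.
bits_per_weight = 4 + 16 / 128
print(bits_per_weight)
```

Smaller groups track local weight ranges more closely at the cost of storing more scales; under the assumptions above, a group size of 128 adds only a small overhead on top of the 4 bits per weight, compared with 16 bits per weight in bfloat16.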

AriaImageProcessor

[[autodoc]] AriaImageProcessor

AriaProcessor

[[autodoc]] AriaProcessor - __call__

AriaTextConfig

[[autodoc]] AriaTextConfig

AriaConfig

[[autodoc]] AriaConfig

AriaTextModel

[[autodoc]] AriaTextModel

AriaModel

[[autodoc]] AriaModel

AriaTextForCausalLM

[[autodoc]] AriaTextForCausalLM

AriaForConditionalGeneration

[[autodoc]] AriaForConditionalGeneration - forward