ranwakhaled/EgyM3AV
Viewer • Updated • 12.7k • 64
How to use ranwakhaled/gemma-vl-exp1 with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("ranwakhaled/gemma-vl-exp1", dtype="auto")

This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
[More Information Needed]
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
# Single-image inference example for ranwakhaled/gemma-vl-exp1.
#
# Two names must be defined by the caller before this script runs:
#   - `image`: PIL image or tensor
#   - `PROMPT`: text instruction

# Select device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and processor
model = AutoModelForImageTextToText.from_pretrained(
    "ranwakhaled/gemma-vl-exp1",
    device_map=device,
).eval()
processor = AutoProcessor.from_pretrained("ranwakhaled/gemma-vl-exp1")

# Single-turn conversation: one user message carrying the image and the prompt
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": PROMPT},
        ],
    }
]

# Prepare inputs using chat template
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)

# Move inputs to the correct device
inputs = {k: v.to(model.device) for k, v in inputs.items()}

# Ensure image tensor matches model dtype (important for some GPUs)
if "pixel_values" in inputs:
    inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)

# Track input length to isolate generated tokens
input_length = inputs["input_ids"].shape[-1]

# Generate output (sampling disabled; at most 256 new tokens)
with torch.inference_mode():
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,
    )

# Remove prompt tokens from the generated sequence; only one conversation
# was submitted, so index 0 is the sequence of interest
generated_tokens = outputs[0][input_length:]

# Decode to text
decoded_output = processor.decode(
    generated_tokens,
    skip_special_tokens=True,
)
print(decoded_output)
[More Information Needed]