TinyLlama-VLM-LoRA

A vision-language model (VLM) built by fine-tuning TinyLlama (1.1B) with LoRA adapters on Flickr30k image-caption pairs, with a frozen CLIP vision encoder grafted on as the image front end. This repository provides:

  • LoRA adapter weights for TinyLlama’s q_proj and v_proj layers (trained on 30K images, 3 epochs).
  • Projector (768→512) and Gate (512→4096) state-dicts (FP16) to convert CLIP’s [CLS] embedding into a “vision token.”
  • Tokenizer files matching TinyLlama’s BPE vocabulary, with <pad> set to <eos>.
  • Inference and evaluation scripts (with BLEU, ROUGE-1, F1, and perplexity).
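
The snippet below is a minimal sketch of how such caption metrics can be computed with the nltk and rouge-score packages from the install step; the reference/hypothesis strings are hypothetical and this is not necessarily the exact logic of evaluate_vlm.py.

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer

reference = "a man in a red shirt rides a bike"   # hypothetical ground-truth caption
hypothesis = "a man rides a red bike"             # hypothetical generated caption

# BLEU with smoothing (short captions often have zero higher-order n-gram overlap)
bleu = sentence_bleu([reference.split()], hypothesis.split(),
                     smoothing_function=SmoothingFunction().method1)

# ROUGE-1 F1 between the reference and the hypothesis
scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True)
rouge1_f1 = scorer.score(reference, hypothesis)["rouge1"].fmeasure

print(f"BLEU: {bleu:.3f}  ROUGE-1 F1: {rouge1_f1:.3f}")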

Model Details

Model Description

TinyLlama-VLM-LoRA is a multi-modal extension of TinyLlama (the 1.1B-parameter Chat model). We freeze CLIP’s ViT-Base vision encoder and feed its [CLS] embedding into TinyLlama as follows:

  1. CLIP Vision Encoder (frozen)

    • Input: RGB image (224×224).
    • Output: 768-dim CLS embedding (FP16).
  2. Projector (nn.Linear(768, 512, dtype=torch.float16))

    • Reduces CLIP’s [CLS] from 768→512.
  3. Gate (nn.Linear(512, 4096, dtype=torch.float16))

    • Expands the 512-dim projection to the LM hidden size (4096), forming a single “vision token” vector.
  4. TinyLlama + LoRA

    • We attach LoRA adapters (r=8, alpha=16, drop=0.1) onto TinyLlama’s q_proj and v_proj layers. During training, only LoRA, Projector, and Gate parameters are updated (TinyLlama’s base weights remain frozen).
    • The “vision token” is prepended to the usual token embeddings. The combined sequence (vision_token + text_tokens[0…L−2]) is fed to TinyLlama to predict the image caption (text_tokens[1…L−1], with non-caption positions masked to −100); see the sketch below.
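
A minimal sketch of this wiring, assuming the dimensions listed above (the function and variable names are illustrative, not the repository's API):

import torch

# Per the card: CLIP ViT-Base [CLS] is 768-dim; the projector maps 768→512 and
# the gate maps 512→4096, the embedding width the vision token must match.
projector = torch.nn.Linear(768, 512, dtype=torch.float16)
gate = torch.nn.Linear(512, 4096, dtype=torch.float16)

def build_multimodal_inputs(cls_embed, text_embeds, attention_mask):
    """cls_embed: (B, 768) CLIP [CLS]; text_embeds: (B, L-1, 4096) token embeddings."""
    vision_token = gate(projector(cls_embed))            # (B, 4096)
    combined = torch.cat([vision_token.unsqueeze(1),     # prepended as position 0
                          text_embeds], dim=1)           # (B, L, 4096)
    vis_attn = torch.ones(attention_mask.size(0), 1,
                          dtype=attention_mask.dtype, device=attention_mask.device)
    return combined, torch.cat([vis_attn, attention_mask], dim=1)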

File/Version Breakdown

  • adapter_config.json, adapter_model.safetensors
    LoRA configuration + learned delta weights for TinyLlama’s q_proj/v_proj (≈4 MB).

  • proj_final.pt (≈1.6 MB)
    nn.Linear(768→512) state_dict after final epoch (FP16).

  • gate_final.pt (≈4.2 MB)
    nn.Linear(512→4096) state_dict after final epoch (FP16).

  • Tokenizer files (tokenizer.json, tokenizer.model, tokenizer_config.json, special_tokens_map.json)
    Exact BPE vocabulary, merges, and special token mappings used during fine-tuning.

  • finetuned_qvlam_flickr30k_final/
    A directory containing:

    • LoRA adapter (adapter_config.json, adapter_model.safetensors)
    • tokenizer.json + related tokenizer artifacts (This folder can be passed directly to PeftModel.from_pretrained and AutoTokenizer.from_pretrained.)
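
For example, a minimal loading sketch (paths assume this repository has been cloned into TinyLlama-VLM-LoRA/):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

ckpt = "TinyLlama-VLM-LoRA/finetuned_qvlam_flickr30k_final"

# Base TinyLlama plus the LoRA deltas described by adapter_config.json
base = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16)
llama = PeftModel.from_pretrained(base, ckpt)
tokenizer = AutoTokenizer.from_pretrained(ckpt)

# Projector and Gate are plain nn.Linear state-dicts with the card's dimensions
projector = torch.nn.Linear(768, 512, dtype=torch.float16)
projector.load_state_dict(torch.load("TinyLlama-VLM-LoRA/proj_final.pt", map_location="cpu"))
gate = torch.nn.Linear(512, 4096, dtype=torch.float16)
gate.load_state_dict(torch.load("TinyLlama-VLM-LoRA/gate_final.pt", map_location="cpu"))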

Uses

TinyLlama-VLM-LoRA is intended for:

  1. Caption Generation
    Given an arbitrary RGB image, produce a descriptive caption in plain English via beam search.

  2. Downstream Fine-Tuning
    Users can further adapt the LoRA adapters on a new, smaller image-caption dataset, or graft additional modules on top of the vision token (see the sketch after this list).

  3. Research / Educational
    Demonstrates a lightweight VLM pipeline (CLIP → TinyLlama) using LoRA, projector, and gate. Useful as a starting point for more advanced multi-modal research (e.g., adding a quantum layer).
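
For use case 2, a minimal sketch of the trainable-parameter setup (assuming llama, projector, and gate are loaded as in the snippet above; the optimizer and learning rate are illustrative, not the original training recipe):

import itertools
import torch

# PeftModel.from_pretrained loads adapters in inference mode, so explicitly
# re-enable gradients for LoRA, Projector, and Gate only, mirroring training.
for p in llama.parameters():
    p.requires_grad = False
for name, p in llama.named_parameters():
    if "lora_" in name:              # PEFT names LoRA tensors lora_A / lora_B
        p.requires_grad = True
for p in itertools.chain(projector.parameters(), gate.parameters()):
    p.requires_grad = True

trainable = [p for p in itertools.chain(llama.parameters(),
                                        projector.parameters(),
                                        gate.parameters()) if p.requires_grad]
# Hypothetical settings; for stability you would typically keep trainable
# weights in FP32 or train with a gradient scaler.
optimizer = torch.optim.AdamW(trainable, lr=1e-4)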

Example Inference (Direct Use)

# Clone this repository
git clone https://huggingface.co/sriram7737/TinyLlama-VLM-LoRA
cd TinyLlama-VLM-LoRA
pip install torch transformers peft pillow nltk rouge-score tqdm

# Download an example image, e.g. test_images/00001.jpg
# Then run the inference script:
python evaluate_vlm.py \
  --ckpt_dir "TinyLlama-VLM-LoRA/finetuned_qvlam_flickr30k_final" \
  --test_image_dir "test_images" \
  --ref_json "refs.json"

# Or to generate a single caption manually:
python - <<'PYCODE'
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor, AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# Load components
device = "cuda" if torch.cuda.is_available() else "cpu"

# 1) CLIP
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.float16).to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# 2) Base TinyLlama + LoRA adapter
base_llama = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
).to(device)
# Point PeftModel at the adapter directory (not at the .safetensors file itself)
llama = PeftModel.from_pretrained(base_llama, "TinyLlama-VLM-LoRA/finetuned_qvlam_flickr30k_final").to(device)
llama.eval()

# 3) Projector & Gate
projector = torch.nn.Linear(768, 512, dtype=torch.float16).to(device)
gate      = torch.nn.Linear(512, 4096, dtype=torch.float16).to(device)
projector.load_state_dict(torch.load("TinyLlama-VLM-LoRA/proj_final.pt", map_location=device))
gate.load_state_dict(torch.load("TinyLlama-VLM-LoRA/gate_final.pt", map_location=device))

# 4) Tokenizer
tokenizer = AutoTokenizer.from_pretrained("TinyLlama-VLM-LoRA/finetuned_qvlam_flickr30k_final")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

# 5) Build VLM wrapper class
class SimpleVisionLanguageModel(torch.nn.Module):
    def __init__(self, clip_model, text_model, projector, gate, device, max_len, prompt):
        super().__init__()
        self.vision = clip_model.vision_model
        for p in self.vision.parameters():
            p.requires_grad = False

        self.text_model = text_model
        self.projector = projector.to(device)
        self.gate = gate.to(device)
        self.device = device
        self.max_length = max_len
        self.prompt_text = prompt

    def forward(self, pixel_values, full_input_ids, attention_mask, prompt_len, labels=None):
        B, L = full_input_ids.size()
        with torch.no_grad():
            vision_out = self.vision(pixel_values=pixel_values)
            cls_embed = vision_out.last_hidden_state[:, 0, :]

        # Projector and Gate were saved in FP16; cast the CLS embedding to the
        # projector's dtype, then map 768 → 512 → 4096 to form the vision token.
        cls_embed = cls_embed.to(self.projector.weight.dtype)
        vision_token = self.gate(self.projector(cls_embed))

        input_ids_trunc = full_input_ids[:, :-1]
        text_embeds = self.text_model.get_input_embeddings()(input_ids_trunc)
        combined_embeds = torch.cat([vision_token.unsqueeze(1), text_embeds], dim=1)

        vision_attn = torch.ones((B, 1), device=self.device, dtype=attention_mask.dtype)
        combined_attn = torch.cat([vision_attn, attention_mask[:, :-1]], dim=1)

        if labels is None:
            # Supervise only the caption tokens: position 0 of the combined
            # sequence is the vision token and position 1+t holds text token t,
            # so labels sit one slot to the right of full_input_ids.
            labels = torch.full((B, L), -100, dtype=torch.long, device=self.device)
            for i in range(B):
                P = prompt_len[i].item()
                total_nonpad = attention_mask[i].sum().item()
                raw_C = total_nonpad - P
                max_C = (L - 1 - P)
                C = min(raw_C, max(0, max_C))
                if C > 0:
                    start = 1 + P
                    end = 1 + P + C
                    labels[i, start:end] = full_input_ids[i, P:P + C]

        outputs = self.text_model(
            inputs_embeds=combined_embeds,
            attention_mask=combined_attn,
            labels=labels
        )
        return outputs

# Instantiate VLM
vlm = SimpleVisionLanguageModel(
    clip_model=clip_model,
    text_model=llama,
    projector=projector,
    gate=gate,
    device=device,
    max_len=50,
    prompt="Describe the image: "
).to(device)
vlm.eval()

# Load a test image
img = Image.open("test_images/00001.jpg").convert("RGB")
pix = clip_processor(images=img, return_tensors="pt").pixel_values.to(device).to(torch.float16)

# Tokenize prompt
prompt_tok = tokenizer("Describe the image: ", return_tensors="pt", max_length=50, truncation=True)
prompt_ids = prompt_tok.input_ids.to(device)
prompt_attn = prompt_tok.attention_mask.to(device)

# Build the multimodal inputs and generate a caption. The wrapper's forward()
# is for training/eval loss; for generation we assemble the embeddings directly.
with torch.no_grad():
    # Vision token: CLIP [CLS] → Projector → Gate (reusing the wrapper's modules)
    cls_embed = vlm.vision(pixel_values=pix).last_hidden_state[:, 0, :]
    vision_token = vlm.gate(vlm.projector(cls_embed.to(vlm.projector.weight.dtype)))

    # Prepend the vision token to the prompt's token embeddings
    text_embeds = llama.get_input_embeddings()(prompt_ids)
    combined_embeds = torch.cat([vision_token.unsqueeze(1), text_embeds], dim=1)
    combined_attn = torch.cat(
        [torch.ones((1, 1), device=device, dtype=prompt_attn.dtype), prompt_attn],
        dim=1
    )

    # Generate with TinyLlama; when called with inputs_embeds (and no input_ids),
    # generate() returns only the newly generated token ids.
    generated_ids = llama.generate(
        inputs_embeds=combined_embeds,
        attention_mask=combined_attn,
        max_new_tokens=50,
        num_beams=3,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True
    )
    caption = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

print("Generated caption:", caption)
PYCODE