TinyLlama-VLM-LoRA
A vision-language model (VLM) built by fine-tuning TinyLlama (1.1B) with LoRA adapters on Flickr30k image-caption pairs, grafted onto a frozen CLIP vision encoder. This repository provides:
- LoRA adapter weights for TinyLlama’s q_proj and v_proj layers (trained on 30K images, 3 epochs).
- Projector (768→512) and Gate (512→4096) state-dicts (FP16) to convert CLIP’s [CLS] embedding into a “vision token.”
- Tokenizer files matching TinyLlama's BPE vocabulary, with <pad> set to <eos>.
- Inference and evaluation scripts (reporting BLEU, ROUGE-1, F1, and perplexity).
Model Details
Model Description
TinyLlama-VLM-LoRA is a multi-modal extension of TinyLlama (the 1.1B-parameter Chat model). We freeze CLIP's ViT-Base vision encoder and inject its [CLS] embedding into TinyLlama through three stages:
CLIP Vision Encoder (frozen)
- Input: RGB image (224×224).
- Output: 768-dim CLS embedding (FP16).
Projector (nn.Linear(768 → 512, dtype=torch.float16))
- Reduces CLIP's [CLS] embedding from 768 to 512 dimensions.
Gate (nn.Linear(512 → 4096, dtype=torch.float16))
- Expands the 512-dim projection to TinyLlama's hidden size (4096), forming a single "vision token" vector.
TinyLlama + LoRA
- We attach LoRA adapters (r=8, alpha=16, dropout=0.1) onto TinyLlama's q_proj and v_proj layers (see the sketch below this list). During training, only the LoRA, Projector, and Gate parameters are updated (TinyLlama's base weights remain frozen).
- The “vision token” is prepended to the usual token embeddings. The combined sequence (vision_token + text_tokens[0…L−2]) is fed to TinyLlama to predict the image caption (text_tokens[1…L−1] masked appropriately).
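The training script itself is not shipped with this repository, but the hyperparameters above map directly onto the standard peft API. A minimal sketch of the adapter setup (the Projector and Gate live outside the peft model and are trained alongside it):

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Sketch only: reproduces the LoRA hyperparameters stated above.
base = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16
)
lora_cfg = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(base, lora_cfg)
model.print_trainable_parameters()  # only the LoRA deltas are trainable here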
File/Version Breakdown
- adapter_config.json, adapter_model.safetensors: LoRA configuration + learned delta weights for TinyLlama's q_proj/v_proj (≈4 MB).
- proj_final.pt (≈1.6 MB): nn.Linear(768→512) state_dict after the final epoch (FP16).
- gate_final.pt (≈4.2 MB): nn.Linear(512→4096) state_dict after the final epoch (FP16).
- Tokenizer files (tokenizer.json, tokenizer.model, tokenizer_config.json, special_tokens_map.json): exact BPE vocabulary, merges, and special-token mappings used during fine-tuning.
- finetuned_qvlam_flickr30k_final/: a directory containing the LoRA adapter (adapter_config.json, adapter_model.safetensors) plus tokenizer.json and related tokenizer artifacts. This folder can be passed directly to PeftModel.from_pretrained and AutoTokenizer.from_pretrained (see the loading sketch below).
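For orientation, restoring all of these pieces locally should look roughly like the following sketch (paths assume the repository has been cloned into the current directory):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", torch_dtype=torch.float16
)
# The folder bundles adapter and tokenizer, so both loaders accept it directly.
llama = PeftModel.from_pretrained(base, "finetuned_qvlam_flickr30k_final")
tokenizer = AutoTokenizer.from_pretrained("finetuned_qvlam_flickr30k_final")

# Projector and Gate are plain nn.Linear layers restored from their state_dicts.
projector = torch.nn.Linear(768, 512, dtype=torch.float16)
gate = torch.nn.Linear(512, 4096, dtype=torch.float16)
projector.load_state_dict(torch.load("proj_final.pt", map_location="cpu"))
gate.load_state_dict(torch.load("gate_final.pt", map_location="cpu"))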
Uses
TinyLlama-VLM-LoRA is intended for:
- Caption Generation: given an arbitrary RGB image, produce a descriptive caption in plain English via beam search.
- Downstream Fine-Tuning: further adapt the LoRA adapters on a new, smaller image-caption dataset, or graft additional modules on top of the vision token (see the sketch after this list).
- Research / Educational: demonstrates a lightweight VLM pipeline (CLIP → TinyLlama) using LoRA, a projector, and a gate; useful as a starting point for more advanced multi-modal research (e.g., adding a quantum layer).
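Continuing from the loading sketch above, the Downstream Fine-Tuning path would reload the adapter in trainable mode; the optimizer choice and learning rate below are illustrative, not the settings used for the released weights:

import torch
from peft import PeftModel

# Reload the shipped adapter with its weights unfrozen (is_trainable=True),
# then optimize the LoRA deltas, Projector, and Gate jointly on new data.
llama = PeftModel.from_pretrained(
    base, "finetuned_qvlam_flickr30k_final", is_trainable=True
)
params = [p for p in llama.parameters() if p.requires_grad]
params += list(projector.parameters()) + list(gate.parameters())
optimizer = torch.optim.AdamW(params, lr=1e-4)  # illustrative hyperparameters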
Example Inference (Direct Use)
# Clone this repository
git clone https://huggingface.co/<your-username>/TinyLlama-VLM-LoRA
cd TinyLlama-VLM-LoRA
pip install torch transformers peft pillow nltk rouge-score tqdm
# Download an example image, e.g. test_images/00001.jpg
# Then run the inference script:
python evaluate_vlm.py \
--ckpt_dir "TinyLlama-VLM-LoRA/finetuned_qvlam_flickr30k_final" \
--test_image_dir "test_images" \
--ref_json "refs.json"
# Or to generate a single caption manually:
python - <<'PYCODE'
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor, AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# Load components
device = "cuda" if torch.cuda.is_available() else "cpu"
# 1) CLIP
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=torch.float16).to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# 2) Base TinyLlama + LoRA adapter
base_llama = AutoModelForCausalLM.from_pretrained(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
torch_dtype=torch.float16,
low_cpu_mem_usage=True
).to(device)
# PeftModel.from_pretrained expects the adapter directory, not the .safetensors file
llama = PeftModel.from_pretrained(base_llama, "TinyLlama-VLM-LoRA/finetuned_qvlam_flickr30k_final").to(device)
llama.eval()
# 3) Projector & Gate
projector = torch.nn.Linear(768, 512, dtype=torch.float16).to(device)
gate = torch.nn.Linear(512, 4096, dtype=torch.float16).to(device)
projector.load_state_dict(torch.load("TinyLlama-VLM-LoRA/proj_final.pt", map_location=device))
gate.load_state_dict(torch.load("TinyLlama-VLM-LoRA/gate_final.pt", map_location=device))
# 4) Tokenizer
tokenizer = AutoTokenizer.from_pretrained("TinyLlama-VLM-LoRA/finetuned_qvlam_flickr30k_final")
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
# 5) Build VLM wrapper class
class SimpleVisionLanguageModel(torch.nn.Module):
def __init__(self, clip_model, text_model, projector, gate, device, max_len, prompt):
super().__init__()
self.vision = clip_model.vision_model
for p in self.vision.parameters():
p.requires_grad = False
self.text_model = text_model
self.projector = projector.to(device)
self.gate = gate.to(device)
self.device = device
self.max_length = max_len
self.prompt_text = prompt
def forward(self, pixel_values, full_input_ids, attention_mask, prompt_len, labels=None):
B, L = full_input_ids.size()
with torch.no_grad():
vision_out = self.vision(pixel_values=pixel_values)
cls_embed = vision_out.last_hidden_state[:, 0, :]
        # Projector + gate run in FP16 to match their stored weights,
        # producing the single "vision token" for this image.
        vision_token = self.gate(self.projector(cls_embed))
input_ids_trunc = full_input_ids[:, :-1]
text_embeds = self.text_model.get_input_embeddings()(input_ids_trunc)
combined_embeds = torch.cat([vision_token.unsqueeze(1), text_embeds], dim=1)
vision_attn = torch.ones((B, 1), device=self.device, dtype=attention_mask.dtype)
combined_attn = torch.cat([vision_attn, attention_mask[:, :-1]], dim=1)
if labels is None:
labels = torch.full((B, L), -100, dtype=torch.long, device=self.device)
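            # Supervise only the caption tokens: the vision token, prompt, and
            # padding positions keep the ignore index (-100) and add no loss.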
for i in range(B):
P = prompt_len[i].item()
total_nonpad = attention_mask[i].sum().item()
raw_C = total_nonpad - P
max_C = (L - 1 - P)
C = min(raw_C, max(0, max_C))
if C > 0:
start = 1 + P
end = 1 + P + C
labels[i, start:end] = full_input_ids[i, P:P + C]
outputs = self.text_model(
inputs_embeds=combined_embeds,
attention_mask=combined_attn,
labels=labels
)
return outputs
# Instantiate VLM
vlm = SimpleVisionLanguageModel(
clip_model=clip_model,
text_model=llama,
projector=projector,
gate=gate,
device=device,
max_len=50,
prompt="Describe the image: "
).to(device)
vlm.eval()
# Load a test image
img = Image.open("test_images/00001.jpg").convert("RGB")
pix = clip_processor(images=img, return_tensors="pt").pixel_values.to(device).to(torch.float16)
# Tokenize prompt
prompt_tok = tokenizer("Describe the image: ", return_tensors="pt", max_length=50, truncation=True)
prompt_ids = prompt_tok.input_ids.to(device)
prompt_attn = prompt_tok.attention_mask.to(device)
prompt_len = torch.tensor([prompt_ids.size(1)], device=device)
# Build the combined (vision token + prompt) embeddings and attention mask directly.
# (The SimpleVisionLanguageModel wrapper above is for training-style forward passes;
# for generation we assemble the inputs ourselves.)
with torch.no_grad():
    vision_out = clip_model.vision_model(pixel_values=pix)
    cls_embed = vision_out.last_hidden_state[:, 0, :]        # (1, 768) CLIP [CLS]
    vision_token = gate(projector(cls_embed))                # (1, 4096) vision token
    text_embeds = llama.get_input_embeddings()(prompt_ids)   # prompt token embeddings
    combined_embeds = torch.cat([vision_token.unsqueeze(1), text_embeds], dim=1)
    combined_attn = torch.cat(
        [torch.ones((1, 1), device=device, dtype=prompt_attn.dtype), prompt_attn],
        dim=1,
    )
# Generate with TinyLlama
generated_ids = llama.generate(
inputs_embeds=combined_embeds,
attention_mask=combined_attn,
    max_new_tokens=50,
num_beams=3,
early_stopping=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
use_cache=True
)
caption = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print("Generated caption:", caption)
PYCODE
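For reference, evaluate_vlm.py reports BLEU, ROUGE-1, F1, and perplexity; the per-caption text metrics presumably resemble the sketch below, built on the nltk and rouge-score packages installed earlier (caption_scores is an illustrative name, not the script's actual API):

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer

def caption_scores(reference: str, hypothesis: str) -> dict:
    # Smoothed sentence-level BLEU (short captions often lack 4-gram overlap).
    smooth = SmoothingFunction().method1
    bleu = sentence_bleu([reference.split()], hypothesis.split(),
                         smoothing_function=smooth)
    # ROUGE-1 F1 between the reference and the generated caption.
    r1 = rouge_scorer.RougeScorer(["rouge1"]).score(reference, hypothesis)["rouge1"]
    return {"bleu": bleu, "rouge1_f1": r1.fmeasure}

print(caption_scores("a dog runs on the grass", "a dog running on grass"))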
Model tree for sriram7737/TinyLlama-VLM-LoRA
- Base model: TinyLlama/TinyLlama-1.1B-Chat-v1.0