import torch
from transformers import (
AutoProcessor,
Qwen2_5_VLForConditionalGeneration,
BitsAndBytesConfig
)
from peft import PeftModel
from PIL import Image
# Your Hugging Face repo
MODEL_REPO = "Jeblest/Qwen-2.5-7B-Instruct-fine-tune-image-caption"
BASE_MODEL = "Qwen/Qwen2.5-VL-7B-Instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Quantization setup
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
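# NF4 with double quantization also stores the quantization constants in
# compressed form, so the 7B weights fit in roughly 5 GB of VRAM (ballpark
# figure) while matmuls still run in bfloat16.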
# 🔄 Load base model with quantization
base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
# 🔄 Load LoRA adapters directly from your Hugging Face repo
model = PeftModel.from_pretrained(
    base_model,
    MODEL_REPO,  # downloads the LoRA adapter config & weights from the Hub
    torch_dtype=torch.bfloat16,
)
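# Note: with full-precision base weights, model.merge_and_unload() could fold
# the LoRA deltas into the base model for slightly faster inference; with the
# 4-bit quantized weights used here, the adapters are kept separate.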
model.eval()
# Load processor
processor = AutoProcessor.from_pretrained(BASE_MODEL)
if processor.tokenizer.pad_token is None:
    processor.tokenizer.pad_token = processor.tokenizer.eos_token
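# Decoder-only models have no pad token by default; reusing EOS is the
# standard fallback and pairs with the left padding used in the collator below.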
class SingleImageCollator:
    """
    A data collator for single-image inference (Gradio or custom input).
    """

    def __init__(self, processor, user_query: str = "Generate a detailed caption based on this image."):
        self.processor = processor
        self.user_query = user_query

    def __call__(self, image: Image.Image):
        # Qwen2.5-VL handles variable resolutions; fixing 448x448 keeps the
        # visual token count (and latency) predictable for single images.
        image = image.convert("RGB").resize((448, 448))
        messages = [{"role": "user", "content": [
            {"type": "text", "text": self.user_query},
            {"type": "image", "image": image},
        ]}]
        text_input = self.processor.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
        return self.processor(
            text=text_input.strip(),
            images=[image],
            return_tensors="pt",
            padding=True,
            padding_side="left",
        )
def infer_single_image(
    model,
    processor,
    image: Image.Image,
    prompt: str = "Generate a detailed caption based on this image.",
    max_new_tokens: int = 100,
    temperature: float = 0.3,
    top_k: int = 30,
    top_p: float = 0.8,
    repetition_penalty: float = 1.1,
    length_penalty: float = 1.0,
    device: str = None,
) -> str:
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # 4-bit bitsandbytes models are already placed by device_map="auto", and
    # calling .to() on them raises an error; only move full-precision models.
    if not getattr(model, "is_loaded_in_4bit", False):
        model.to(device)
    model.eval()
    collator = SingleImageCollator(processor, user_query=prompt)
    inputs = collator(image)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,  # only used by beam search; ignored under sampling
            pad_token_id=processor.tokenizer.pad_token_id,
        )
    # Decode only the newly generated tokens, skipping the echoed prompt.
    generated_text = processor.batch_decode(
        generated_ids[:, inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    )[0]
    return generated_text.strip()
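# --- Minimal usage sketch: "example.jpg" is a hypothetical local path,
# replace it with a real image file before running. ---
if __name__ == "__main__":
    test_image = Image.open("example.jpg")
    print(infer_single_image(model, processor, test_image))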