# LLaVA-LoRA Adapter

This is a LoRA adapter for the LLaVA model, fine-tuned for spatial description tasks.
## Base Model

This adapter is trained on top of [llava-hf/llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf).
## Training

The model was fine-tuned with LoRA using the following configuration (see the `LoraConfig` sketch after this list):

- Rank: 8
- Alpha: 32
- Target modules: q_proj, v_proj, k_proj
- Dataset: PersReFex validation set
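
For reference, here is how the configuration above maps onto a `peft` `LoraConfig`. This is a minimal sketch: only the rank, alpha, and target modules come from this card; the dropout value and `task_type` are illustrative assumptions, not details of the actual training run.

```python
from peft import LoraConfig

# Sketch of the adapter configuration described above.
# Only r, lora_alpha, and target_modules come from this model card;
# lora_dropout and task_type are assumed values for illustration.
lora_config = LoraConfig(
    r=8,                                             # LoRA rank
    lora_alpha=32,                                   # scaling factor
    target_modules=["q_proj", "v_proj", "k_proj"],   # attention projections
    lora_dropout=0.05,                               # assumed value
    task_type="CAUSAL_LM",                           # assumed; LLaVA decodes autoregressively
)
```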
## Usage
```python
import torch
from PIL import Image
from peft import PeftModel
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Load the base model in bfloat16 and move it to the GPU
base_model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.bfloat16,
).to("cuda")
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Attach the LoRA adapter to the base model
model = PeftModel.from_pretrained(
    base_model,
    "ZinengTang/llava-lora-spatial",
)

init_prompt_instruct = "Describe the location of the blue sphere relative to the environment features."
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": init_prompt_instruct},
            {"type": "image"},  # Placeholder; the processor pairs it with the image passed below
        ],
    },
]

speaker_image = Image.open("your_image_path")

# Render the conversation into LLaVA's prompt format, ending with "ASSISTANT:"
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
# print(prompt)

# Tokenize the prompt, preprocess the image, and cast to the model's dtype
# (BatchFeature.to() casts only floating-point tensors, so input_ids stay integer)
inputs = processor(
    images=speaker_image,
    text=prompt,
    return_tensors="pt",
).to("cuda", torch.bfloat16)

with torch.no_grad():
    generated = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=256,  # bound on newly generated tokens, independent of prompt length
        num_beams=1,
        do_sample=True,
        temperature=0.7,
    )

generated_message = processor.batch_decode(
    generated,
    skip_special_tokens=True,
)
print(generated_message)

# Keep only the assistant's reply, truncated to 100 characters
generated_message = generated_message[0].split('ASSISTANT: ')[-1][:100]
```
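
If you want to serve the model without a `peft` dependency at inference time, the adapter weights can be folded into the base model with `merge_and_unload`. This is a general `peft` pattern, not something this card prescribes, and the output path below is illustrative:

```python
# Fold the LoRA weights into the base model and drop the PeftModel wrapper;
# the result is a plain LlavaForConditionalGeneration with the adapter baked in.
merged_model = model.merge_and_unload()
merged_model.save_pretrained("llava-lora-spatial-merged")  # path is illustrative
```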