CrystalRaindropsFall committed on
Commit 8982a79 · verified · 1 Parent(s): da79c1a

Upload folder using huggingface_hub

Files changed (1)
  1. README.md +35 -0
README.md CHANGED
@@ -16,3 +16,38 @@ This is a pruned version of [LLaVA-1.5-7b](https://huggingface.co/llava-hf/llava
  - **Sparsity**: 30%
 
  This model was pruned to improve efficiency while maintaining performance.
+
+ ## Usage
+
+ Because the pruning is structural, the model still loads with the standard `LlavaForConditionalGeneration` class. Load the processor from the base model so that inputs are preprocessed correctly.
+
+ ```python
+ from transformers import AutoProcessor, LlavaForConditionalGeneration
+ import torch
+
+ model_id = "CrystalRaindropsFall/llava-heads-30pct"
+ base_model_id = "llava-hf/llava-1.5-7b-hf"
+
+ # 1. Load the processor from the base model
+ processor = AutoProcessor.from_pretrained(base_model_id)
+
+ # 2. Load the pruned model
+ model = LlavaForConditionalGeneration.from_pretrained(
+     model_id,
+     torch_dtype=torch.float16,
+     device_map="auto"
+ )
+
+ # Example inference
+ from PIL import Image
+ import requests
+
+ url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_logo.png?raw=true"
+ image = Image.open(requests.get(url, stream=True).raw)
+ prompt = "USER: <image>\nWhat is shown in this image?\nASSISTANT:"
+
+ inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, model.dtype)
+
+ output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
+ print(processor.decode(output[0], skip_special_tokens=True))
+ ```
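
The string decoded in the usage example contains the prompt as well as the reply, because `generate` returns the full token sequence for this kind of model. Below is a minimal sketch of building the prompt and splitting off the reply, assuming the single-turn `USER: ... ASSISTANT:` format used in the README example. `build_prompt`, `extract_answer`, and `ASSISTANT_TAG` are hypothetical names introduced here for illustration; they are not part of `transformers` or of this repository.

```python
# Hypothetical helpers (not part of transformers), assuming the
# "USER: <image>\n...\nASSISTANT:" prompt format from the README example.
ASSISTANT_TAG = "ASSISTANT:"


def build_prompt(question: str) -> str:
    """Wrap a question in the single-turn prompt format with one image slot."""
    return f"USER: <image>\n{question}\n{ASSISTANT_TAG}"


def extract_answer(decoded: str) -> str:
    """Return the text after the last ASSISTANT: tag.

    If the tag is absent, the whole string is returned (stripped).
    """
    return decoded.rpartition(ASSISTANT_TAG)[2].strip()
```

With these in place, the last line of the README example could become `print(extract_answer(processor.decode(output[0], skip_special_tokens=True)))`, which prints only the model's reply.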