CanerDedeoglu committed on
Commit
76edc19
·
verified ·
1 Parent(s): dd67cb7

Upload predict (1).py

Browse files
Files changed (1) hide show
  1. predict (1).py +155 -0
predict (1).py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
4
+ from llava.conversation import conv_templates, SeparatorStyle
5
+ from llava.model.builder import load_pretrained_model
6
+ from llava.utils import disable_torch_init
7
+ from llava.mm_utils import tokenizer_image_token
8
+ from transformers.generation.streamers import TextIteratorStreamer
9
+
10
+ from PIL import Image
11
+
12
+ import requests
13
+ from io import BytesIO
14
+
15
+ from cog import BasePredictor, Input, Path, ConcatenateIterator
16
+ import time
17
+ import subprocess
18
+ from threading import Thread
19
+
20
+ import os
21
# Keep Hugging Face hub downloads inside a "weights" folder under the current
# working directory. os.path.join is used instead of string concatenation so
# the path has no doubled separator when the working directory is "/".
# NOTE(review): this assignment runs after the llava/transformers imports at
# the top of the file -- confirm the hub cache location is still picked up.
os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(os.getcwd(), "weights")

# url for the weights mirror
REPLICATE_WEIGHTS_URL = "https://weights.replicate.delivery/default"

# files to download from the weights mirrors
# Each entry has:
#   "dest":  local directory the files are stored in
#   "src":   path below REPLICATE_WEIGHTS_URL the files are fetched from
#   "files": file names to fetch from "src" into "dest"
weights = [
    {
        "dest": "liuhaotian/llava-v1.5-13b",
        # git commit hash from huggingface
        "src": "llava-v1.5-13b/006818fc465ebda4c003c0998674d9141d8d95f8",
        "files": [
            "config.json",
            "generation_config.json",
            "pytorch_model-00001-of-00003.bin",
            "pytorch_model-00002-of-00003.bin",
            "pytorch_model-00003-of-00003.bin",
            "pytorch_model.bin.index.json",
            "special_tokens_map.json",
            "tokenizer.model",
            "tokenizer_config.json",
        ],
    },
    {
        "dest": "openai/clip-vit-large-patch14-336",
        "src": "clip-vit-large-patch14-336/ce19dc912ca5cd21c8a653c79e251e808ccabcd1",
        "files": [
            "config.json",
            "preprocessor_config.json",
            "pytorch_model.bin",
        ],
    },
]
53
+
54
def download_json(url: str, dest: Path, timeout: float = 60.0):
    """Fetch ``url`` with ``requests`` and write the response body to ``dest``.

    Args:
        url: Address to fetch; redirects are followed.
        dest: Path object the body is written to, opened in binary mode.
        timeout: Seconds passed to ``requests.get`` so that a stalled server
            cannot block the caller indefinitely.

    The download is best-effort: when the status code is not 200 or the body
    is empty, nothing is written and a message is printed instead of raising.
    Exceptions raised by ``requests.get`` itself are not caught here.
    """
    res = requests.get(url, allow_redirects=True, timeout=timeout)
    if res.status_code == 200 and res.content:
        with dest.open("wb") as f:
            f.write(res.content)
    else:
        print(f"Failed to download {url}. Status code: {res.status_code}")
61
+
62
def download_weights(baseurl: str, basedest: str, files: list[str]):
    """Download every missing file in ``files`` from the weights mirror.

    Args:
        baseurl: Path below ``REPLICATE_WEIGHTS_URL`` that holds the files.
        basedest: Local directory to store the files in; created if missing.
        files: File names to fetch. Names that already exist in ``basedest``
            are skipped.

    ``.json`` files are fetched with ``download_json``; everything else is
    fetched by running the external ``pget`` command.

    Raises:
        subprocess.CalledProcessError: if ``pget`` exits with a non-zero status.
    """
    basedest = Path(basedest)
    start = time.time()
    print("downloading to: ", basedest)
    basedest.mkdir(parents=True, exist_ok=True)
    for name in files:
        dest = basedest / name
        if dest.exists():
            continue
        # URLs always use "/" as the separator. os.path.join follows the
        # host OS convention and drops earlier parts when a later part starts
        # with "/", so the pieces are joined explicitly instead.
        url = "/".join(
            [REPLICATE_WEIGHTS_URL.rstrip("/"), baseurl.strip("/"), name.lstrip("/")]
        )
        print("downloading url: ", url)
        if dest.suffix == ".json":
            download_json(url, dest)
        else:
            subprocess.check_call(["pget", url, str(dest)], close_fds=False)
    print("downloading took: ", time.time() - start)
77
+
78
class Predictor(BasePredictor):
    """Cog predictor that answers a single prompt about a single image with LLaVA."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        # Fetch every file listed in the module-level `weights` manifest.
        for weight in weights:
            download_weights(weight["src"], weight["dest"], weight["files"])
        disable_torch_init()

        (
            self.tokenizer,
            self.model,
            self.image_processor,
            self.context_len,
        ) = load_pretrained_model(
            "liuhaotian/llava-v1.5-13b",
            model_name="llava-v1.5-13b",
            model_base=None,
            load_8bit=False,
            load_4bit=False,
        )

    def predict(
        self,
        image: Path = Input(description="Input image"),
        prompt: str = Input(description="Prompt to use for text generation"),
        top_p: float = Input(description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens", ge=0.0, le=1.0, default=1.0),
        temperature: float = Input(description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic", default=0.2, ge=0.0),
        max_tokens: int = Input(description="Maximum number of tokens to generate. A word is generally 2-3 tokens", default=1024, ge=0),
    ) -> ConcatenateIterator[str]:
        """Run a single prediction on the model

        Builds a one-turn "llava_v1" conversation from ``prompt`` with the
        image token prepended, starts ``self.model.generate`` on a background
        thread, and yields the generated text piece by piece as the streamer
        produces it. A trailing stop string is removed from the output.
        """
        conv_mode = "llava_v1"
        conv = conv_templates[conv_mode].copy()

        image_data = load_image(str(image))
        image_tensor = self.image_processor.preprocess(image_data, return_tensors='pt')['pixel_values'].half().cuda()

        # just one turn, always prepend image token
        inp = DEFAULT_IMAGE_TOKEN + '\n' + prompt
        conv.append_message(conv.roles[0], inp)
        # The assistant turn is left empty for the model to fill in.
        conv.append_message(conv.roles[1], None)
        full_prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(full_prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, timeout=20.0)

        generate_kwargs = dict(
            inputs=input_ids,
            images=image_tensor,
            max_new_tokens=max_tokens,
            streamer=streamer,
            use_cache=True,
        )
        if temperature > 0:
            generate_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
        else:
            # The input is documented as "0 is deterministic" and 0.0 is an
            # allowed value, so fall back to greedy decoding instead of
            # sampling with a zero temperature.
            generate_kwargs["do_sample"] = False

        # NOTE(review): PyTorch documents inference_mode as thread-local, so
        # this context presumably does not extend to the generation thread
        # started below -- confirm whether generate() needs its own guard.
        with torch.inference_mode():
            thread = Thread(target=self.model.generate, kwargs=generate_kwargs)
            thread.start()
            # workaround: second-to-last token is always " "
            # but we want to keep it if it's not the second-to-last token
            prepend_space = False
            for new_text in streamer:
                if new_text == " ":
                    # Hold the space back until we know more text follows.
                    prepend_space = True
                    continue
                if new_text.endswith(stop_str):
                    new_text = new_text[:-len(stop_str)].strip()
                    prepend_space = False
                elif prepend_space:
                    new_text = " " + new_text
                    prepend_space = False
                if new_text:
                    yield new_text
            if prepend_space:
                yield " "
            thread.join()
146
+
147
+
148
def load_image(image_file, timeout=30.0):
    """Open an image from a URL or a local path and return it converted to RGB.

    Args:
        image_file: String that is either an ``http://`` / ``https://`` URL or
            a path accepted by ``Image.open``.
        timeout: Seconds passed to ``requests.get`` when ``image_file`` is a
            URL, so a stalled server cannot block the prediction forever.

    Returns:
        The image produced by ``Image.open(...).convert('RGB')``.

    Raises:
        requests.HTTPError: if the URL responds with an error status. Without
            this check the error page would be handed to the image decoder.
    """
    # Match the full scheme so that local files whose names merely start with
    # "http" (for example "http_cat.png") are still opened from disk.
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file, timeout=timeout)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    else:
        image = Image.open(image_file)
    return image.convert('RGB')
155
+