import torch

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import tokenizer_image_token
from transformers.generation.streamers import TextIteratorStreamer

from PIL import Image

import requests
from io import BytesIO

from cog import BasePredictor, Input, Path, ConcatenateIterator
import time
import subprocess
from threading import Thread

import os

# Keep Hugging Face downloads inside the project directory so they live with
# the code rather than in the default per-user cache.
os.environ["HUGGINGFACE_HUB_CACHE"] = os.getcwd() + "/weights"

# Base URL of the mirror that hosts the model weights.
REPLICATE_WEIGHTS_URL = "https://weights.replicate.delivery/default"

# Manifest of weights to fetch: `src` is the path under REPLICATE_WEIGHTS_URL
# (pinned to an upstream Hugging Face commit hash) and `dest` is the local
# directory the files are downloaded into.
weights = [
    {
        "dest": "liuhaotian/llava-v1.5-13b",
        "src": "llava-v1.5-13b/006818fc465ebda4c003c0998674d9141d8d95f8",
        "files": [
            "config.json",
            "generation_config.json",
            "pytorch_model-00001-of-00003.bin",
            "pytorch_model-00002-of-00003.bin",
            "pytorch_model-00003-of-00003.bin",
            "pytorch_model.bin.index.json",
            "special_tokens_map.json",
            "tokenizer.model",
            "tokenizer_config.json",
        ],
    },
    {
        "dest": "openai/clip-vit-large-patch14-336",
        "src": "clip-vit-large-patch14-336/ce19dc912ca5cd21c8a653c79e251e808ccabcd1",
        "files": [
            "config.json",
            "preprocessor_config.json",
            "pytorch_model.bin",
        ],
    },
]
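

# Download strategy: small JSON files are fetched in-process with `requests`,
# while large binaries are handed off to `pget` (assumed here to be Replicate's
# parallel downloader, available on PATH inside the container).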
def download_json(url: str, dest: Path):
    res = requests.get(url, allow_redirects=True)
    if res.status_code == 200 and res.content:
        with dest.open("wb") as f:
            f.write(res.content)
    else:
        print(f"Failed to download {url}. Status code: {res.status_code}")


def download_weights(baseurl: str, basedest: str, files: list[str]):
    basedest = Path(basedest)
    start = time.time()
    print("downloading to: ", basedest)
    basedest.mkdir(parents=True, exist_ok=True)
    for f in files:
        dest = basedest / f
        url = os.path.join(REPLICATE_WEIGHTS_URL, baseurl, f)
        # Skip files that already exist, so repeated setup calls are cheap.
        if not dest.exists():
            print("downloading url: ", url)
            if dest.suffix == ".json":
                download_json(url, dest)
            else:
                subprocess.check_call(["pget", url, str(dest)], close_fds=False)
    print("downloading took: ", time.time() - start)


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        for weight in weights:
            download_weights(weight["src"], weight["dest"], weight["files"])
        disable_torch_init()

        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            "liuhaotian/llava-v1.5-13b",
            model_name="llava-v1.5-13b",
            model_base=None,
            load_8bit=False,
            load_4bit=False,
        )

    def predict(
        self,
        image: Path = Input(description="Input image"),
        prompt: str = Input(description="Prompt to use for text generation"),
        top_p: float = Input(description="When decoding text, samples from the smallest set of tokens whose cumulative probability exceeds top_p; lower it to ignore less likely tokens", ge=0.0, le=1.0, default=1.0),
        temperature: float = Input(description="Adjusts randomness of outputs; higher values are more random, 0 is deterministic", default=0.2, ge=0.0),
        max_tokens: int = Input(description="Maximum number of tokens to generate. A word is generally 1-3 tokens", default=1024, ge=0),
    ) -> ConcatenateIterator[str]:
        """Run a single prediction on the model"""
        conv_mode = "llava_v1"
        conv = conv_templates[conv_mode].copy()

        image_data = load_image(str(image))
        # Preprocess to the vision tower's expected resolution; fp16 on GPU.
        image_tensor = self.image_processor.preprocess(image_data, return_tensors='pt')['pixel_values'].half().cuda()

        # Prepend the image placeholder to the user prompt and render the full
        # llava_v1 conversation prompt.
        inp = DEFAULT_IMAGE_TOKEN + '\n' + prompt
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        # Tokenize, splicing IMAGE_TOKEN_INDEX in place of the image placeholder.
        input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]  # unused; stopping is handled by stripping stop_str below
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, timeout=20.0)

        # Run generation on a background thread so tokens can be yielded from
        # `streamer` as they are produced.
        with torch.inference_mode():
            thread = Thread(target=self.model.generate, kwargs=dict(
                inputs=input_ids,
                images=image_tensor,
                # Fall back to greedy decoding at temperature 0, so the
                # advertised "0 is deterministic" actually holds.
                do_sample=temperature > 0,
                temperature=temperature,
                top_p=top_p,
                max_new_tokens=max_tokens,
                streamer=streamer,
                use_cache=True))
            thread.start()
            # Strip the stop string from the end of the stream, and hold back a
            # lone trailing space so it is only emitted if more text follows.
            prepend_space = False
            for new_text in streamer:
                if new_text == " ":
                    prepend_space = True
                    continue
                if new_text.endswith(stop_str):
                    new_text = new_text[:-len(stop_str)].strip()
                    prepend_space = False
                elif prepend_space:
                    new_text = " " + new_text
                    prepend_space = False
                if len(new_text):
                    yield new_text
            if prepend_space:
                yield " "
            thread.join()


def load_image(image_file):
    # Accept either an http(s) URL or a local file path; normalize to RGB.
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image
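

# Minimal local smoke test: a sketch, not part of the Cog entrypoint. It
# assumes a CUDA GPU, access to the weights mirror, and a local image at
# "example.jpg" (a hypothetical path).
if __name__ == "__main__":
    predictor = Predictor()
    predictor.setup()
    for text in predictor.predict(
        image=Path("example.jpg"),
        prompt="Describe this image.",
        top_p=1.0,
        temperature=0.2,
        max_tokens=256,
    ):
        print(text, end="", flush=True)
    print()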