CanerDedeoglu committed on
Commit
593f474
·
verified ·
1 Parent(s): 8d66e99

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +106 -141
predict.py CHANGED
@@ -1,155 +1,120 @@
1
- import torch
 
2
 
3
- from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
4
- from llava.conversation import conv_templates, SeparatorStyle
5
- from llava.model.builder import load_pretrained_model
6
- from llava.utils import disable_torch_init
7
- from llava.mm_utils import tokenizer_image_token
8
- from transformers.generation.streamers import TextIteratorStreamer
 
9
 
10
- from PIL import Image
11
 
12
- import requests
13
- from io import BytesIO
14
 
15
- from cog import BasePredictor, Input, Path, ConcatenateIterator
16
- import time
17
- import subprocess
18
- from threading import Thread
19
 
20
- import os
21
# Keep the Hugging Face hub cache in a "weights" directory under the current
# working directory.
os.environ["HUGGINGFACE_HUB_CACHE"] = f"{os.getcwd()}/weights"

# Base URL of the weights mirror.
REPLICATE_WEIGHTS_URL = "https://weights.replicate.delivery/default"

# Manifest of files to fetch from the mirror. For each entry, "src" is the
# path below REPLICATE_WEIGHTS_URL, "dest" is the local directory the files
# are written to, and "files" lists the file names to download.
weights = [
    {
        "dest": "liuhaotian/llava-v1.5-13b",
        # "src" ends with the git commit hash from Hugging Face.
        "src": "llava-v1.5-13b/006818fc465ebda4c003c0998674d9141d8d95f8",
        "files": [
            "config.json",
            "generation_config.json",
            # The checkpoint is split into three numbered shards.
            *(f"pytorch_model-{shard:05d}-of-00003.bin" for shard in range(1, 4)),
            "pytorch_model.bin.index.json",
            "special_tokens_map.json",
            "tokenizer.model",
            "tokenizer_config.json",
        ],
    },
    {
        "dest": "openai/clip-vit-large-patch14-336",
        "src": "clip-vit-large-patch14-336/ce19dc912ca5cd21c8a653c79e251e808ccabcd1",
        "files": [
            "config.json",
            "preprocessor_config.json",
            "pytorch_model.bin",
        ],
    },
]
53
-
54
def download_json(url: str, dest: Path, timeout: float = 60.0):
    """Download ``url`` and write the response body to ``dest``.

    This is best effort: if the response status is not 200 or the body is
    empty, a message is printed and ``dest`` is left untouched.

    Parameters
    ----------
    url: str
        Address to fetch; redirects are followed.
    dest: Path
        File that receives the raw response bytes.
    timeout: float
        Seconds to wait for the server before ``requests`` gives up. Without
        a timeout a stalled connection would block forever.
    """
    res = requests.get(url, allow_redirects=True, timeout=timeout)
    if res.status_code == 200 and res.content:
        with dest.open("wb") as f:
            f.write(res.content)
    else:
        print(f"Failed to download {url}. Status code: {res.status_code}")
61
-
62
def download_weights(baseurl: str, basedest: str, files: list[str]):
    """Fetch every missing file in ``files`` from the weights mirror.

    Parameters
    ----------
    baseurl: str
        Path below ``REPLICATE_WEIGHTS_URL`` that holds the files.
    basedest: str
        Local directory to download into; created if it does not exist.
    files: list[str]
        File names to fetch. Files that already exist locally are skipped.

    Raises
    ------
    subprocess.CalledProcessError
        If ``pget`` exits with a non-zero status.
    """
    basedest = Path(basedest)
    start = time.monotonic()  # monotonic: unaffected by wall-clock changes
    print("downloading to: ", basedest)
    basedest.mkdir(parents=True, exist_ok=True)
    for f in files:
        dest = basedest / f
        if dest.exists():
            continue
        # URLs always use "/" as the separator; os.path.join would insert the
        # platform path separator instead.
        url = "/".join((REPLICATE_WEIGHTS_URL, baseurl, f))
        print("downloading url: ", url)
        if dest.suffix == ".json":
            download_json(url, dest)
        else:
            # Everything that is not JSON is fetched with the external
            # ``pget`` command.
            subprocess.check_call(["pget", url, str(dest)], close_fds=False)
    print("downloading took: ", time.monotonic() - start)
77
 
78
  class Predictor(BasePredictor):
 
 
79
  def setup(self) -> None:
80
- """Load the model into memory to make running multiple predictions efficient"""
81
- for weight in weights:
82
- download_weights(weight["src"], weight["dest"], weight["files"])
83
- disable_torch_init()
84
-
85
- self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model("liuhaotian/llava-v1.5-13b", model_name="llava-v1.5-13b", model_base=None, load_8bit=False, load_4bit=False)
 
 
 
 
 
86
 
87
  def predict(
88
  self,
89
- image: Path = Input(description="Input image"),
90
- prompt: str = Input(description="Prompt to use for text generation"),
91
- top_p: float = Input(description="When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens", ge=0.0, le=1.0, default=1.0),
92
- temperature: float = Input(description="Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic", default=0.2, ge=0.0),
93
- max_tokens: int = Input(description="Maximum number of tokens to generate. A word is generally 2-3 tokens", default=1024, ge=0),
94
- ) -> ConcatenateIterator[str]:
95
- """Run a single prediction on the model"""
96
-
97
- conv_mode = "llava_v1"
98
- conv = conv_templates[conv_mode].copy()
99
-
100
- image_data = load_image(str(image))
101
- image_tensor = self.image_processor.preprocess(image_data, return_tensors='pt')['pixel_values'].half().cuda()
102
-
103
- # loop start
104
-
105
- # just one turn, always prepend image token
106
- inp = DEFAULT_IMAGE_TOKEN + '\n' + prompt
107
- conv.append_message(conv.roles[0], inp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
- conv.append_message(conv.roles[1], None)
110
- prompt = conv.get_prompt()
111
-
112
- input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
113
- stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
114
- keywords = [stop_str]
115
- streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True, timeout=20.0)
116
-
117
- with torch.inference_mode():
118
- thread = Thread(target=self.model.generate, kwargs=dict(
119
- inputs=input_ids,
120
- images=image_tensor,
121
- do_sample=True,
122
- temperature=temperature,
123
- top_p=top_p,
124
- max_new_tokens=max_tokens,
125
- streamer=streamer,
126
- use_cache=True))
127
- thread.start()
128
- # workaround: second-to-last token is always " "
129
- # but we want to keep it if it's not the second-to-last token
130
- prepend_space = False
131
- for new_text in streamer:
132
- if new_text == " ":
133
- prepend_space = True
134
- continue
135
- if new_text.endswith(stop_str):
136
- new_text = new_text[:-len(stop_str)].strip()
137
- prepend_space = False
138
- elif prepend_space:
139
- new_text = " " + new_text
140
- prepend_space = False
141
- if len(new_text):
142
- yield new_text
143
- if prepend_space:
144
- yield " "
145
- thread.join()
146
-
147
 
148
def load_image(image_file):
    """Open an image from a local path or an HTTP(S) URL and convert it to RGB.

    Parameters
    ----------
    image_file: str
        Either a filesystem path or a URL starting with ``http://`` or
        ``https://``.

    Returns
    -------
    PIL.Image.Image
        The decoded image in RGB mode.

    Raises
    ------
    requests.HTTPError
        If a URL is given and the server answers with an error status.
    """
    # Match the full scheme so a local file whose name merely starts with
    # "http" is still opened from disk.
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file, timeout=60)
        # Fail on HTTP errors here instead of trying to decode an error page
        # as an image.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    return Image.open(source).convert('RGB')
 
155
 
 
 
 
 
1
+ """
2
+ Cog prediction script for the PULSE ECG model.
3
 
4
+ This module defines a ``Predictor`` class compatible with the Replicate
5
+ Cog framework. It delegates model loading and inference to the
6
+ ``EndpointHandler`` defined in ``handler.py``. The predictor exposes a
7
+ simple ``predict`` method that accepts an image and a prompt, along with
8
+ optional sampling parameters. The response is the generated text
9
+ answer from the model.
10
+ """
11
 
12
+ from typing import Optional
13
 
14
+ from cog import BasePredictor, Input, Path
 
15
 
16
+ from handler import EndpointHandler
 
 
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  class Predictor(BasePredictor):
20
+ """Cog predictor for the PULSE ECG model."""
21
+
22
  def setup(self) -> None:
23
+ """Load the model on startup.
24
+
25
+ Instantiates the ``EndpointHandler``. The underlying model
26
+ weights and vision tower are loaded during the handler's
27
+ initialisation; this only happens once when the Cog server
28
+ starts.
29
+ """
30
+ # Instantiate the handler. Any environment variables
31
+ # controlling model selection (e.g. ``HF_MODEL_ID`` or
32
+ # ``PULSE_MODEL_REPO``) should be set before Cog starts.
33
+ self.handler = EndpointHandler()
34
 
35
  def predict(
36
  self,
37
+ image: Path = Input(description="Input ECG image file"),
38
+ prompt: str = Input(description="Question to ask about the ECG"),
39
+ temperature: float = Input(
40
+ description="Randomness of generation; 0 for deterministic outputs",
41
+ default=0.0,
42
+ ge=0.0,
43
+ ),
44
+ top_p: float = Input(
45
+ description="Nucleus sampling parameter; consider tokens in the top p cumulative probability",
46
+ default=0.9,
47
+ ge=0.0,
48
+ le=1.0,
49
+ ),
50
+ max_tokens: int = Input(
51
+ description="Maximum number of new tokens to generate",
52
+ default=512,
53
+ ge=0,
54
+ ),
55
+ repetition_penalty: float = Input(
56
+ description="Penalise repetition; 1.0 means no penalty",
57
+ default=1.0,
58
+ ge=0.0,
59
+ ),
60
+ conv_mode: Optional[str] = Input(
61
+ description="Override the conversation template (e.g. 'llava_v1')",
62
+ default=None,
63
+ ),
64
+ ) -> str:
65
+ """Generate a textual response for an ECG image and prompt.
66
+
67
+ Parameters
68
+ ----------
69
+ image: Path
70
+ Path to the input image file. Cog will save uploaded
71
+ images to a temporary location and pass the path here.
72
+ prompt: str
73
+ The question to ask about the ECG image.
74
+ temperature: float
75
+ Sampling temperature; higher values yield more random
76
+ results.
77
+ top_p: float
78
+ Top-p (nucleus) sampling; lower values focus on more
79
+ likely tokens.
80
+ max_tokens: int
81
+ Maximum number of tokens to generate beyond the prompt.
82
+ repetition_penalty: float
83
+ Penalty for repeating tokens; values >1.0 discourage
84
+ repetition.
85
+ conv_mode: Optional[str]
86
+ Optional conversation template override. If provided, the
87
+ handler will use this template instead of inferring one
88
+ from the model name.
89
 
90
+ Returns
91
+ -------
92
+ str
93
+ The generated answer from the model.
94
+ """
95
+ # Prepare the inputs for the handler. Note: the handler expects
96
+ # ``max_new_tokens`` rather than ``max_tokens`` for the length of
97
+ # the generated sequence.
98
+ event = {
99
+ "image": str(image),
100
+ "prompt": prompt,
101
+ "temperature": temperature,
102
+ "top_p": top_p,
103
+ "max_new_tokens": max_tokens,
104
+ "repetition_penalty": repetition_penalty,
105
+ }
106
+ if conv_mode:
107
+ event["conv_mode"] = conv_mode
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
+ # Invoke the handler. The handler returns a dictionary which
110
+ # includes either a ``generated_text`` key on success or an
111
+ # ``error`` key on failure.
112
+ result = self.handler(event)
113
+ if isinstance(result, dict):
114
+ if "error" in result:
115
+ raise ValueError(result["error"])
116
+ return result.get("generated_text", result.get("answer", ""))
117
 
118
+ # If the handler returned a raw string (older versions), just
119
+ # return it directly.
120
+ return str(result)