File size: 4,321 Bytes
593f474
 
76edc19
593f474
 
 
 
 
 
 
76edc19
593f474
76edc19
593f474
76edc19
593f474
76edc19
 
 
593f474
 
76edc19
593f474
 
 
 
 
 
 
 
 
 
 
76edc19
 
 
593f474
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76edc19
593f474
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76edc19
593f474
 
 
 
 
 
 
 
76edc19
593f474
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
Cog prediction script for the PULSE ECG model.

This module defines a ``Predictor`` class compatible with the Replicate
Cog framework.  It delegates model loading and inference to the
``EndpointHandler`` defined in ``handler.py``.  The predictor exposes a
simple ``predict`` method that accepts an image and a prompt, along with
optional sampling parameters.  The response is the generated text
answer from the model.
"""

from typing import Optional

from cog import BasePredictor, Input, Path

from handler import EndpointHandler


class Predictor(BasePredictor):
    """Cog predictor for the PULSE ECG model.

    Thin adapter: ``predict`` translates Cog inputs into the event
    dictionary consumed by ``handler.EndpointHandler`` and converts the
    handler's result back into a plain answer string.
    """

    def setup(self) -> None:
        """Load the model on startup.

        Instantiates the ``EndpointHandler``.  The underlying model
        weights and vision tower are loaded during the handler's
        initialisation; this only happens once when the Cog server
        starts.
        """
        # Any environment variables controlling model selection
        # (e.g. ``HF_MODEL_ID`` or ``PULSE_MODEL_REPO``) should be set
        # before Cog starts.
        self.handler = EndpointHandler()

    def predict(
        self,
        image: Path = Input(description="Input ECG image file"),
        prompt: str = Input(description="Question to ask about the ECG"),
        temperature: float = Input(
            description="Randomness of generation; 0 for deterministic outputs",
            default=0.0,
            ge=0.0,
        ),
        top_p: float = Input(
            description="Nucleus sampling parameter; consider tokens in the top p cumulative probability",
            default=0.9,
            ge=0.0,
            le=1.0,
        ),
        max_tokens: int = Input(
            description="Maximum number of new tokens to generate",
            default=512,
            ge=0,
        ),
        repetition_penalty: float = Input(
            description="Penalise repetition; 1.0 means no penalty",
            default=1.0,
            ge=0.0,
        ),
        conv_mode: Optional[str] = Input(
            description="Override the conversation template (e.g. 'llava_v1')",
            default=None,
        ),
    ) -> str:
        """Generate a textual response for an ECG image and prompt.

        Parameters
        ----------
        image: Path
            Path to the input image file.  Cog saves uploaded images to a
            temporary location and passes the path here.
        prompt: str
            The question to ask about the ECG image.
        temperature: float
            Sampling temperature; higher values yield more random results.
        top_p: float
            Top-p (nucleus) sampling; lower values focus on more likely
            tokens.
        max_tokens: int
            Maximum number of tokens to generate beyond the prompt.
            Forwarded to the handler as ``max_new_tokens``.
        repetition_penalty: float
            Penalty for repeating tokens; values >1.0 discourage
            repetition.
        conv_mode: Optional[str]
            Optional conversation template override.  Only forwarded to
            the handler when non-empty.

        Returns
        -------
        str
            The generated answer from the model.

        Raises
        ------
        ValueError
            If the handler reports an ``error``.
        """
        # The handler expects ``max_new_tokens`` rather than
        # ``max_tokens`` for the length of the generated sequence.
        event = {
            "image": str(image),
            "prompt": prompt,
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_tokens,
            "repetition_penalty": repetition_penalty,
        }
        if conv_mode:
            event["conv_mode"] = conv_mode

        result = self.handler(event)
        return self._extract_text(result)

    @staticmethod
    def _extract_text(result: object) -> str:
        """Normalise the handler's return value into the answer string.

        Accepted shapes:

        * a dict with ``generated_text`` (preferred) or ``answer``;
        * a one-element list/tuple wrapping such a dict (the
          ``List[Dict]`` shape used by Hugging Face custom handlers);
        * any other value (e.g. a raw string from older handler
          versions), which is converted with ``str``.

        Raises
        ------
        ValueError
            If the (unwrapped) result dict carries an ``error`` key.
        """
        # Unwrap a single-element sequence so its dict is inspected
        # instead of being stringified as a Python repr.
        if isinstance(result, (list, tuple)) and len(result) == 1:
            result = result[0]

        if isinstance(result, dict):
            if "error" in result:
                raise ValueError(result["error"])
            text = result.get("generated_text", result.get("answer", ""))
            # A ``None`` value would violate the ``-> str`` contract.
            return "" if text is None else str(text)

        return str(result)