|
|
""" |
|
|
Cog prediction script for the PULSE ECG model. |
|
|
|
|
|
This module defines a ``Predictor`` class compatible with the Replicate |
|
|
Cog framework. It delegates model loading and inference to the |
|
|
``EndpointHandler`` defined in ``handler.py``. The predictor exposes a |
|
|
simple ``predict`` method that accepts an image and a prompt, along with |
|
|
optional sampling parameters. The response is the generated text |
|
|
answer from the model. |
|
|
""" |
|
|
|
|
|
from typing import Optional |
|
|
|
|
|
from cog import BasePredictor, Input, Path |
|
|
|
|
|
from handler import EndpointHandler |
|
|
|
|
|
|
|
|
class Predictor(BasePredictor):
    """Cog predictor for the PULSE ECG model.

    Thin adapter between the Cog serving framework and the
    ``EndpointHandler`` defined in ``handler.py``: ``setup`` builds the
    handler once, and ``predict`` translates Cog inputs into the event
    dict the handler expects.
    """

    def setup(self) -> None:
        """Load the model on startup.

        Instantiates the ``EndpointHandler`` exactly once, when the Cog
        server boots. The handler's own initialisation is what loads
        the model weights and vision tower, so no per-request loading
        happens afterwards.
        """
        self.handler = EndpointHandler()

    def predict(
        self,
        image: Path = Input(description="Input ECG image file"),
        prompt: str = Input(description="Question to ask about the ECG"),
        temperature: float = Input(
            description="Randomness of generation; 0 for deterministic outputs",
            default=0.0,
            ge=0.0,
        ),
        top_p: float = Input(
            description="Nucleus sampling parameter; consider tokens in the top p cumulative probability",
            default=0.9,
            ge=0.0,
            le=1.0,
        ),
        max_tokens: int = Input(
            description="Maximum number of new tokens to generate",
            default=512,
            ge=0,
        ),
        repetition_penalty: float = Input(
            description="Penalise repetition; 1.0 means no penalty",
            default=1.0,
            ge=0.0,
        ),
        conv_mode: Optional[str] = Input(
            description="Override the conversation template (e.g. 'llava_v1')",
            default=None,
        ),
    ) -> str:
        """Generate a textual answer for an ECG image and prompt.

        Args:
            image: Path to the input image file. Cog saves uploaded
                images to a temporary location and passes the path here.
            prompt: The question to ask about the ECG image.
            temperature: Sampling temperature; higher values produce
                more random output, 0 is deterministic.
            top_p: Nucleus (top-p) sampling threshold; lower values
                restrict generation to the most likely tokens.
            max_tokens: Maximum number of new tokens generated beyond
                the prompt (forwarded as ``max_new_tokens``).
            repetition_penalty: Penalty applied to repeated tokens;
                values above 1.0 discourage repetition.
            conv_mode: Optional conversation-template override. When
                set (non-empty), the handler uses it instead of
                inferring a template from the model name.

        Returns:
            The model's generated answer as a string.

        Raises:
            ValueError: If the handler reports an error in its result.
        """
        # Assemble the request payload in the shape EndpointHandler expects.
        payload = {
            "image": str(image),
            "prompt": prompt,
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_tokens,
            "repetition_penalty": repetition_penalty,
        }
        # Only forward a truthy override; None (and "") fall back to the
        # handler's own template inference.
        if conv_mode:
            payload["conv_mode"] = conv_mode

        result = self.handler(payload)

        # Non-dict results (unexpected handler output) are stringified
        # as a best-effort fallback.
        if not isinstance(result, dict):
            return str(result)
        if "error" in result:
            raise ValueError(result["error"])
        # Prefer the canonical key, then the legacy one, then empty.
        return result.get("generated_text", result.get("answer", ""))