# Rapid_ECG / predict.py
# Provenance: Hugging Face repository file by CanerDedeoglu,
# commit message "Update predict.py", revision 593f474 (verified).
"""
Cog prediction script for the PULSE ECG model.
This module defines a ``Predictor`` class compatible with the Replicate
Cog framework. It delegates model loading and inference to the
``EndpointHandler`` defined in ``handler.py``. The predictor exposes a
simple ``predict`` method that accepts an image and a prompt, along with
optional sampling parameters. The response is the generated text
answer from the model.
"""
from typing import Optional
from cog import BasePredictor, Input, Path
from handler import EndpointHandler
class Predictor(BasePredictor):
    """Replicate Cog predictor wrapping the PULSE ECG ``EndpointHandler``."""

    def setup(self) -> None:
        """Initialise the model once at Cog server startup.

        Constructing ``EndpointHandler`` loads the model weights and
        vision tower. Environment variables that select the model
        (e.g. ``HF_MODEL_ID`` or ``PULSE_MODEL_REPO``) must already be
        set before the Cog server launches, since they are read during
        handler initialisation.
        """
        self.handler = EndpointHandler()

    def predict(
        self,
        image: Path = Input(description="Input ECG image file"),
        prompt: str = Input(description="Question to ask about the ECG"),
        temperature: float = Input(
            description="Randomness of generation; 0 for deterministic outputs",
            default=0.0,
            ge=0.0,
        ),
        top_p: float = Input(
            description="Nucleus sampling parameter; consider tokens in the top p cumulative probability",
            default=0.9,
            ge=0.0,
            le=1.0,
        ),
        max_tokens: int = Input(
            description="Maximum number of new tokens to generate",
            default=512,
            ge=0,
        ),
        repetition_penalty: float = Input(
            description="Penalise repetition; 1.0 means no penalty",
            default=1.0,
            ge=0.0,
        ),
        conv_mode: Optional[str] = Input(
            description="Override the conversation template (e.g. 'llava_v1')",
            default=None,
        ),
    ) -> str:
        """Answer *prompt* about the ECG shown in *image*.

        Parameters
        ----------
        image: Path
            Path to the uploaded image file; Cog stores uploads in a
            temporary location and passes the path here.
        prompt: str
            The question to ask about the ECG image.
        temperature: float
            Sampling temperature; higher values produce more random
            outputs.
        top_p: float
            Nucleus sampling threshold; lower values restrict sampling
            to more likely tokens.
        max_tokens: int
            Upper bound on the number of newly generated tokens.
        repetition_penalty: float
            Values greater than 1.0 discourage token repetition.
        conv_mode: Optional[str]
            Optional conversation-template override; when given, the
            handler uses it instead of inferring one from the model name.

        Returns
        -------
        str
            The model's generated answer.

        Raises
        ------
        ValueError
            If the handler reports a failure via an ``error`` key.
        """
        # The handler's generation-length knob is called
        # ``max_new_tokens``, so translate Cog's ``max_tokens`` here.
        payload = {
            "image": str(image),
            "prompt": prompt,
            "temperature": temperature,
            "top_p": top_p,
            "max_new_tokens": max_tokens,
            "repetition_penalty": repetition_penalty,
        }
        if conv_mode:
            payload["conv_mode"] = conv_mode

        result = self.handler(payload)

        # Older handler versions returned a bare string rather than a
        # dict; normalise that case first.
        if not isinstance(result, dict):
            return str(result)
        # On failure the handler reports via an ``error`` key; surface
        # it to the caller as an exception.
        if "error" in result:
            raise ValueError(result["error"])
        # Success: prefer ``generated_text``, falling back to the
        # legacy ``answer`` key, then to an empty string.
        return result.get("generated_text", result.get("answer", ""))