""" Cog prediction script for the PULSE ECG model. This module defines a ``Predictor`` class compatible with the Replicate Cog framework. It delegates model loading and inference to the ``EndpointHandler`` defined in ``handler.py``. The predictor exposes a simple ``predict`` method that accepts an image and a prompt, along with optional sampling parameters. The response is the generated text answer from the model. """ from typing import Optional from cog import BasePredictor, Input, Path from handler import EndpointHandler class Predictor(BasePredictor): """Cog predictor for the PULSE ECG model.""" def setup(self) -> None: """Load the model on startup. Instantiates the ``EndpointHandler``. The underlying model weights and vision tower are loaded during the handler's initialisation; this only happens once when the Cog server starts. """ # Instantiate the handler. Any environment variables # controlling model selection (e.g. ``HF_MODEL_ID`` or # ``PULSE_MODEL_REPO``) should be set before Cog starts. self.handler = EndpointHandler() def predict( self, image: Path = Input(description="Input ECG image file"), prompt: str = Input(description="Question to ask about the ECG"), temperature: float = Input( description="Randomness of generation; 0 for deterministic outputs", default=0.0, ge=0.0, ), top_p: float = Input( description="Nucleus sampling parameter; consider tokens in the top p cumulative probability", default=0.9, ge=0.0, le=1.0, ), max_tokens: int = Input( description="Maximum number of new tokens to generate", default=512, ge=0, ), repetition_penalty: float = Input( description="Penalise repetition; 1.0 means no penalty", default=1.0, ge=0.0, ), conv_mode: Optional[str] = Input( description="Override the conversation template (e.g. 'llava_v1')", default=None, ), ) -> str: """Generate a textual response for an ECG image and prompt. Parameters ---------- image: Path Path to the input image file. Cog will save uploaded images to a temporary location and pass the path here. prompt: str The question to ask about the ECG image. temperature: float Sampling temperature; higher values yield more random results. top_p: float Top-p (nucleus) sampling; lower values focus on more likely tokens. max_tokens: int Maximum number of tokens to generate beyond the prompt. repetition_penalty: float Penalty for repeating tokens; values >1.0 discourage repetition. conv_mode: Optional[str] Optional conversation template override. If provided, the handler will use this template instead of inferring one from the model name. Returns ------- str The generated answer from the model. """ # Prepare the inputs for the handler. Note: the handler expects # ``max_new_tokens`` rather than ``max_tokens`` for the length of # the generated sequence. event = { "image": str(image), "prompt": prompt, "temperature": temperature, "top_p": top_p, "max_new_tokens": max_tokens, "repetition_penalty": repetition_penalty, } if conv_mode: event["conv_mode"] = conv_mode # Invoke the handler. The handler returns a dictionary which # includes either a ``generated_text`` key on success or an # ``error`` key on failure. result = self.handler(event) if isinstance(result, dict): if "error" in result: raise ValueError(result["error"]) return result.get("generated_text", result.get("answer", "")) # If the handler returned a raw string (older versions), just # return it directly. return str(result)