# salmonaude.ai / predict.py
# Uploaded by NoCritics: "Upload folder using huggingface_hub"
# Commit 857620c (verified)
import os
from typing import Optional
from cog import BasePredictor, Input, Path
from llama_cpp import Llama
class Predictor(BasePredictor):
    """Cog predictor that serves a Mistral-style GGUF model through llama.cpp."""

    def setup(self) -> None:
        """Load the model into memory.

        Raises:
            FileNotFoundError: If the GGUF weights file is missing from the
                working directory.
        """
        model_path = "monad-mistral-7b.gguf"

        # Nothing in this method downloads the weights; the file is expected
        # to already be in place. Fail fast with a clear message instead of
        # printing and then letting the loader fail on the missing path.
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model not found at {model_path}")

        self.llm = Llama(
            model_path=model_path,
            n_ctx=4096,  # Context window
            n_threads=8,  # CPU threads
            n_gpu_layers=-1,  # Use all GPU layers
            verbose=False,
        )

        # Default generation parameters. These mirror the Input defaults of
        # predict(); predict() itself does not read this dict.
        self.default_params = {
            "max_tokens": 512,
            "temperature": 0.7,
            "top_p": 0.9,
            "top_k": 40,
            "repeat_penalty": 1.1,
        }

    def predict(
        self,
        prompt: str = Input(
            description="Input prompt",
            default="What is Monad blockchain?"
        ),
        system_prompt: str = Input(
            description="System prompt to guide the model's behavior",
            default="You are an expert on Monad blockchain technology. Provide accurate, helpful information about Monad's architecture, ecosystem, and capabilities."
        ),
        max_tokens: int = Input(
            description="Maximum number of tokens to generate",
            default=512,
            ge=1,
            le=4096
        ),
        temperature: float = Input(
            description="Temperature for sampling",
            default=0.7,
            ge=0.1,
            le=2.0
        ),
        top_p: float = Input(
            description="Top-p sampling parameter",
            default=0.9,
            ge=0.1,
            le=1.0
        ),
        top_k: int = Input(
            description="Top-k sampling parameter",
            default=40,
            ge=1,
            le=100
        ),
        repeat_penalty: float = Input(
            description="Penalty for repeated tokens",
            default=1.1,
            ge=1.0,
            le=2.0
        ),
        seed: int = Input(
            description="Random seed for reproducibility",
            default=-1
        )
    ) -> str:
        """Run inference on the model.

        Args:
            prompt: User message to answer.
            system_prompt: Optional instructions placed before the user
                message inside the same instruction block; an empty string
                disables it.
            max_tokens: Upper bound on the number of generated tokens.
            temperature: Sampling temperature.
            top_p: Nucleus sampling threshold.
            top_k: Number of highest-probability tokens kept for sampling.
            repeat_penalty: Penalty applied to repeated tokens.
            seed: Sampling seed. Any value >= 0 is forwarded to the model;
                a negative value (the default) leaves sampling unseeded.

        Returns:
            The generated completion text with surrounding whitespace removed.
        """
        # Format prompt with Mistral template
        if system_prompt:
            formatted_prompt = f"[INST] {system_prompt}\n\n{prompt} [/INST]"
        else:
            formatted_prompt = f"[INST] {prompt} [/INST]"

        # llama.cpp samples with its own RNG, so the seed has to be handed to
        # the model call itself; seeding `random` / `numpy` does not influence
        # generation. Zero is a legitimate seed, only negatives mean "unset".
        sampling_seed = seed if seed >= 0 else None

        # Generate response
        output = self.llm(
            formatted_prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repeat_penalty=repeat_penalty,
            stop=["</s>", "[INST]", "[/INST]"],
            echo=False,
            seed=sampling_seed,
        )
        return output['choices'][0]['text'].strip()