from dataclasses import dataclass
from typing import Optional
import logging
import torch
from cpufeature import CPUFeature
from petals.constants import PUBLIC_INITIAL_PEERS
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("app.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
@dataclass
class ModelInfo:
    """Describes one servable model: a base repository plus an optional adapter."""

    # Repository identifier for the base model.
    repo: str
    # NOTE(review): presumably a fine-tuned adapter repo layered on the base
    # model — confirm with the consumers of MODELS.
    adapter: Optional[str] = None
# Catalog of models this app can serve; each entry pairs a base repo with an
# optional adapter.
MODELS = [
    ModelInfo(repo="meta-llama/Llama-2-70b-chat-hf"),
    ModelInfo(repo="stabilityai/StableBeluga2"),
    ModelInfo(repo="enoch/llama-65b-hf"),
    # Same base model as the entry above, served with a Guanaco adapter.
    ModelInfo(repo="enoch/llama-65b-hf", adapter="timdettmers/guanaco-65b"),
    ModelInfo(repo="bigscience/bloomz"),
    # NOTE(review): "roda-1" and the "kubu-hai.*" identifiers below lack the
    # usual "org/name" repo format — confirm these actually resolve.
    ModelInfo(repo="roda-1"),
    ModelInfo(repo="kubu-hai.model.h5-2", adapter="kubu-hai.model.mat-2"),
]

# Model used when the caller does not select one explicitly.
DEFAULT_MODEL_NAME = "enoch/llama-65b-hf"

# Bootstrap peers for joining the public Petals swarm.
INITIAL_PEERS = PUBLIC_INITIAL_PEERS
# Select the compute device and a matching dtype.  On GPU the dtype is left
# to the library ("auto"); on CPU, bfloat16 is used only when AVX-512 is
# both present in hardware and enabled by the OS, otherwise float32.
if torch.cuda.is_available():
    DEVICE = "cuda"
    TORCH_DTYPE = "auto"
else:
    DEVICE = "cpu"
    if CPUFeature["AVX512f"] and CPUFeature["OS_AVX512"]:
        TORCH_DTYPE = torch.bfloat16
    else:
        TORCH_DTYPE = torch.float32
# Timeout for a single step, in seconds (10 minutes).
STEP_TIMEOUT = 10 * 60
# NOTE(review): presumably the cap on concurrently served sessions — confirm
# against the code that consumes it.
MAX_SESSIONS = 50
logger.info("Configuration setup complete.")
def preprocess(data):
    """Pass-through preprocessing hook: logs and returns *data* unchanged.

    Placeholder for real preprocessing logic; currently an identity function.
    """
    logger.debug("Preprocessing data")
    return data
def postprocess(data):
    """Pass-through postprocessing hook: logs and returns *data* unchanged.

    Placeholder for real postprocessing logic; currently an identity function.
    """
    logger.debug("Postprocessing data")
    return data
class MyModel(torch.nn.Module):
    """Identity model: ``forward`` returns its input unchanged.

    Placeholder architecture — replace with real layers as needed.
    """

    def __init__(self):
        # Python 3 zero-argument super() replaces super(MyModel, self).
        super().__init__()

    def forward(self, x):
        """Return *x* unchanged (identity mapping)."""
        return x
# Instantiate the (identity) model on the selected device.
model = MyModel().to(DEVICE)
def hybrid_function(data):
    """Run *data* through the CPU pre/post hooks around a DEVICE forward pass.

    The input is moved to CPU for preprocessing, to DEVICE for the model's
    forward pass, and back to CPU for postprocessing; the postprocessed
    result is returned.
    """
    on_cpu = preprocess(data.to("cpu"))
    output = model(on_cpu.to(DEVICE))
    return postprocess(output.to("cpu"))
# Smoke-run: push a random (100, 10) batch through the pipeline at import
# time.  NOTE(review): this executes on import — consider guarding with
# `if __name__ == "__main__":` if that side effect is unintended.
data = torch.randn(100, 10).to(DEVICE)
result = hybrid_function(data)
logger.info("Processing complete.")