File size: 1,676 Bytes
5b42a0e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 | """
Base model interface for LLM interactions.
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
import logging
logger = logging.getLogger(__name__)


class BaseModel(ABC):
    """Abstract base class for LLM models.

    Subclasses must implement :meth:`generate` and :meth:`generate_batch`.
    The base class also provides helpers for building chat messages and
    flattening a list of them into a single plain-text prompt.
    """

    # Speaker label used for each supported chat role when flattening
    # messages into a prompt; roles not listed here are skipped.
    _ROLE_LABELS = {
        "system": "System",
        "user": "User",
        "assistant": "Assistant",
    }

    def __init__(self, model_name: str, **kwargs):
        """Store the model identifier.

        Args:
            model_name: Identifier of the underlying model.
            **kwargs: Accepted so subclasses can forward extra options
                through ``super().__init__``; ignored by the base class.
        """
        self.model_name = model_name
        # ``name`` mirrors ``model_name``.
        self.name = model_name

    @abstractmethod
    def generate(
        self,
        prompt: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """Generate a response from the model.

        Args:
            prompt: Input prompt text.
            max_tokens: Intended upper bound on generated tokens; exact
                interpretation is left to the implementation.
            temperature: Sampling temperature passed to the implementation.
            **kwargs: Implementation-specific generation options.

        Returns:
            The generated text.
        """

    @abstractmethod
    def generate_batch(
        self,
        prompts: List[str],
        max_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> List[str]:
        """Generate responses for a batch of prompts.

        Args:
            prompts: Input prompt texts.
            max_tokens: Intended upper bound on generated tokens per prompt.
            temperature: Sampling temperature passed to the implementation.
            **kwargs: Implementation-specific generation options.

        Returns:
            Generated texts; implementations are expected to return one
            string per prompt, in input order.
        """

    def wrap_as_chat_message(self, content: str, role: str = "user") -> Dict[str, str]:
        """Wrap content as a chat message.

        Args:
            content: Message text.
            role: Speaker role; defaults to ``"user"``.

        Returns:
            A ``{"role": role, "content": content}`` dictionary.
        """
        return {"role": role, "content": content}

    def format_chat_messages(self, messages: List[Dict[str, str]]) -> str:
        """Flatten chat messages into a single plain-text prompt.

        Each supported message becomes ``"<Label>: <content>"`` followed by a
        blank line, where the label is ``System``, ``User`` or ``Assistant``.
        A message without a ``role`` key is treated as ``"user"``; a missing
        ``content`` key becomes an empty string. Messages with any other role
        are skipped and a warning is logged. The result always ends with
        ``"Assistant:"`` so the model is cued to produce the next reply.

        Args:
            messages: Chat messages, e.g. as built by
                :meth:`wrap_as_chat_message`.

        Returns:
            The formatted prompt string.
        """
        parts = []
        for msg in messages:
            role = msg.get("role", "user")
            label = self._ROLE_LABELS.get(role)
            if label is None:
                # Unsupported roles (e.g. "tool") have no prompt label. They
                # are left out of the prompt, but the omission is logged so
                # lost content is not invisible.
                logger.warning("Skipping chat message with unsupported role %r", role)
                continue
            parts.append(f"{label}: {msg.get('content', '')}\n\n")
        parts.append("Assistant:")
        # join avoids repeated string concatenation in the loop.
        return "".join(parts)
|