"""Client utilities for interacting with the NexaSci Assistant LLM."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Sequence
import torch
import yaml
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
@dataclass(frozen=True)
class Message:
"""Represents a single conversational turn."""
role: str
content: str
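# A Message mirrors a single turn in the transcript; the roles this module
# understands are "system", "user", "assistant" and "tool". Illustrative example
# (the content is hypothetical):
#
#   Message(role="user", content="Summarise the latest experiment results.")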
@dataclass(frozen=True)
class ModelConfig:
"""Configuration for loading the NexaSci model."""
base_repo: str
merged_path: Optional[str]
adapter_path: Optional[str]
backend: str
torch_dtype: Optional[str]
trust_remote_code: bool
@dataclass(frozen=True)
class GenerationSettings:
"""Settings governing text generation."""
max_new_tokens: int
temperature: float
top_p: float
repetition_penalty: float
tool_prefix: str
tool_suffix: str
final_prefix: str
final_suffix: str
stop_sequences: Sequence[str]
system_prompt: str
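# The client expects an ``agent/config.yaml`` with ``model``, ``generation`` and an
# optional ``tooling`` section. A minimal sketch with placeholder repo names and
# paths (the keys mirror what _parse_model_config/_parse_generation_settings read
# below; defaults apply for anything omitted):
#
#   model:
#     base_repo: "org/base-model"        # placeholder
#     merged_path: "models/merged"       # optional, resolved relative to this file
#     adapter_path: null
#     backend: "transformers"
#     torch_dtype: "bfloat16"
#     trust_remote_code: true
#   generation:
#     max_new_tokens: 512
#     temperature: 0.3
#     top_p: 0.9
#     repetition_penalty: 1.05
#     system_prompt: "You are the NexaSci Assistant, a scientific research agent."
#   tooling:
#     available_tools: []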
class NexaSciModelClient:
"""High-level client for loading and querying the NexaSci Assistant."""
def __init__(self, config_path: Path | str = "agent/config.yaml", lazy_load: bool = False) -> None:
"""Initialise the client by loading configuration and model weights.
Parameters
----------
config_path:
Path to the agent configuration YAML file.
lazy_load:
If True, delay model loading until first generation call.
"""
self._config_path = Path(config_path)
if not self._config_path.exists():
raise FileNotFoundError(f"Configuration file not found at {self._config_path}")
raw_cfg = _load_yaml(self._config_path)
self.model_config = _parse_model_config(raw_cfg["model"])
self.generation_settings = _parse_generation_settings(raw_cfg["generation"])
self._tooling_config = raw_cfg.get("tooling", {})
self._tokenizer: Any | None = None
self._model: AutoModelForCausalLM | None = None
self._lazy_load = lazy_load
if not lazy_load:
print("Loading tokenizer and model...")
self._tokenizer = self._load_tokenizer()
self._model = self._load_model()
print("✓ Model loaded")
else:
print("Model will be loaded on first generation call")
@property
def tokenizer(self) -> Any:
"""Lazy-load tokenizer if needed."""
if self._tokenizer is None:
self._tokenizer = self._load_tokenizer()
return self._tokenizer
@property
def model(self) -> AutoModelForCausalLM:
"""Lazy-load model if needed."""
if self._model is None:
print("Loading model (this may take 30-60 seconds)...")
self._model = self._load_model()
print("✓ Model loaded")
return self._model
@property
def available_tools(self) -> Sequence[str]:
"""Return the list of tool identifiers declared in configuration."""
return tuple(self._tooling_config.get("available_tools", []))
def _resolve_model_path(self, path: Optional[str]) -> str:
"""Resolve a model path, handling relative paths relative to config file."""
if path is None:
return self.model_config.base_repo
path_obj = Path(path)
if path_obj.is_absolute():
return str(path_obj)
        # Resolve relative to the directory containing the config file
config_dir = self._config_path.parent
resolved = (config_dir / path_obj).resolve()
return str(resolved)
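    # Example with hypothetical paths: with the config at /project/agent/config.yaml
    # and merged_path set to "models/merged", this resolves to
    # /project/agent/models/merged; absolute paths are returned unchanged, and a
    # missing merged_path falls back to base_repo.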
def _load_tokenizer(self) -> Any:
"""Load the tokenizer for the configured model."""
source = self._resolve_model_path(self.model_config.merged_path)
print(f" Loading tokenizer from: {source}")
tokenizer = AutoTokenizer.from_pretrained(
source,
trust_remote_code=self.model_config.trust_remote_code,
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print(" ✓ Tokenizer loaded")
return tokenizer
def _load_model(self) -> AutoModelForCausalLM:
"""Load the base or merged model for inference."""
source = self._resolve_model_path(self.model_config.merged_path)
torch_dtype = _resolve_torch_dtype(self.model_config.torch_dtype)
print(f" Loading model from: {source}")
print(f" Using dtype: {torch_dtype}")
# Check CUDA availability
if torch.cuda.is_available():
print(f" ✓ CUDA available: {torch.cuda.get_device_name(0)}")
print(f" ✓ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.1f} GB")
device_map = "auto"
else:
print(" ⚠️ WARNING: CUDA not available! Model will load on CPU (very slow)")
print(" Check: 1) NVIDIA drivers installed, 2) PyTorch with CUDA support, 3) GPU visible")
device_map = None
print(f" This may take 30-60 seconds...")
model = AutoModelForCausalLM.from_pretrained(
source,
device_map=device_map,
torch_dtype=torch_dtype,
low_cpu_mem_usage=True,
trust_remote_code=self.model_config.trust_remote_code,
)
model.eval()
# Verify device
if torch.cuda.is_available():
model_device = next(model.parameters()).device
if model_device.type == "cuda":
print(f" ✓ Model loaded on GPU: {model_device}")
else:
print(f" ⚠️ WARNING: Model loaded on {model_device}, not GPU!")
else:
print(" ⚠️ Model loaded on CPU (will be very slow)")
return model
    def build_chat_messages(self, messages: Iterable[Message]) -> List[Dict[str, str]]:
        """Convert internal message objects to the tokenizer chat format."""
        # Materialise the iterable first so a one-shot iterator is not exhausted by
        # the system-prompt check before the formatting loop runs.
        message_list = list(messages)
        formatted: List[Dict[str, str]] = []
        system_present = any(message.role == "system" for message in message_list)
        if not system_present:
            formatted.append(
                {
                    "role": "system",
                    "content": self.generation_settings.system_prompt,
                }
            )
        for message in message_list:
            formatted.append({"role": message.role, "content": message.content})
        return formatted
def _format_prompt(self, messages: Sequence[Message]) -> str:
"""Format messages into a prompt string for Falcon models."""
parts = []
system_present = any(msg.role == "system" for msg in messages)
if not system_present:
parts.append(f"System: {self.generation_settings.system_prompt}")
        for message in messages:
            role = message.role.capitalize()
            # Only recognised roles contribute to the prompt; anything else is dropped.
            if role in ("System", "User", "Assistant", "Tool"):
                parts.append(f"{role}: {message.content}")
parts.append("Assistant:")
return "\n\n".join(parts)
def generate(
self,
messages: Sequence[Message],
*,
max_new_tokens: Optional[int] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
) -> str:
"""Generate a response from the model given a message history."""
# Check if tokenizer has a chat template, otherwise format manually
if hasattr(self.tokenizer, "chat_template") and self.tokenizer.chat_template is not None:
chat_messages = self.build_chat_messages(messages)
inputs = self.tokenizer.apply_chat_template(
chat_messages,
tokenize=True,
return_tensors="pt",
add_generation_prompt=True,
)
            # apply_chat_template may return a tensor or a BatchEncoding, depending on version and arguments
if isinstance(inputs, torch.Tensor):
input_ids = inputs.to(self.model.device)
else:
input_ids = inputs["input_ids"].to(self.model.device)
else:
# Manual formatting for models without chat templates (e.g., Falcon)
prompt_text = self._format_prompt(messages)
tokenized = self.tokenizer(
prompt_text,
return_tensors="pt",
add_special_tokens=True,
)
input_ids = tokenized["input_ids"].to(self.model.device)
        temp = temperature if temperature is not None else self.generation_settings.temperature
        top_p_val = top_p if top_p is not None else self.generation_settings.top_p
        # Enable sampling when temperature/top_p call for it; an explicit temperature of
        # 0.0 now means greedy decoding instead of silently falling back to the config default.
        do_sample = temp > 0.0 or top_p_val < 1.0
generation_config = GenerationConfig(
max_new_tokens=max_new_tokens or self.generation_settings.max_new_tokens,
temperature=temp,
top_p=top_p_val,
do_sample=do_sample,
repetition_penalty=self.generation_settings.repetition_penalty,
pad_token_id=self.tokenizer.pad_token_id,
eos_token_id=self.tokenizer.eos_token_id,
)
# Ensure model is loaded
model = self.model
model_device = next(model.parameters()).device
if model_device.type == "cuda":
torch.cuda.empty_cache()
            free_mem = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)
            free_gb = free_mem / (1024**3)
            # Approximation: total capacity minus this process's allocations, ignoring
            # memory reserved by the caching allocator or held by other processes.
            print(f"  💾 GPU memory: ~{free_gb:.1f} GB free")
if free_gb < 5:
print(f" ⚠️ Warning: Low GPU memory ({free_gb:.1f} GB)")
else:
print(f" ⚠️ WARNING: Model is on {model_device}, not GPU! Generation will be very slow.")
print(f" This is likely why it's freezing. Check CUDA installation.")
try:
print(f" 🚀 Starting generation (max {generation_config.max_new_tokens} tokens)...")
with torch.inference_mode():
output_ids = model.generate(
input_ids=input_ids,
generation_config=generation_config,
do_sample=do_sample,
)
print(f" ✓ Generation complete")
except torch.cuda.OutOfMemoryError as e:
torch.cuda.empty_cache()
raise RuntimeError(
f"GPU out of memory. Try: 1) Reduce max_new_tokens, 2) Use CPU, "
f"3) Close other GPU processes. Original error: {e}"
) from e
except RuntimeError as e:
if "out of memory" in str(e).lower():
torch.cuda.empty_cache()
raise RuntimeError(
f"GPU out of memory. Try reducing max_new_tokens or closing other processes."
) from e
raise
except Exception as e:
raise RuntimeError(f"Generation failed: {e}") from e
generated_ids = output_ids[0, input_ids.shape[-1] :]
generated_text = self.tokenizer.decode(
generated_ids,
skip_special_tokens=True,
)
return generated_text.strip()
def _parse_model_config(raw_config: Dict[str, Any]) -> ModelConfig:
"""Validate and coerce the raw model configuration."""
return ModelConfig(
base_repo=raw_config["base_repo"],
merged_path=raw_config.get("merged_path"),
adapter_path=raw_config.get("adapter_path"),
backend=raw_config.get("backend", "transformers"),
torch_dtype=raw_config.get("torch_dtype"),
trust_remote_code=bool(raw_config.get("trust_remote_code", True)),
)
def _parse_generation_settings(raw_config: Dict[str, Any]) -> GenerationSettings:
"""Validate and coerce generation settings."""
return GenerationSettings(
max_new_tokens=int(raw_config.get("max_new_tokens", 512)),
temperature=float(raw_config.get("temperature", 0.3)),
top_p=float(raw_config.get("top_p", 0.9)),
repetition_penalty=float(raw_config.get("repetition_penalty", 1.05)),
tool_prefix=str(raw_config.get("tool_prefix", "~~~toolcall")),
tool_suffix=str(raw_config.get("tool_suffix", "~~~")),
final_prefix=str(raw_config.get("final_prefix", "~~~final")),
final_suffix=str(raw_config.get("final_suffix", "~~~")),
stop_sequences=tuple(raw_config.get("stop_sequences", [])),
system_prompt=str(
raw_config.get(
"system_prompt",
"You are the NexaSci Assistant, a scientific research agent.",
)
),
)
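# The marker defaults above imply the model wraps tool calls as ``~~~toolcall ... ~~~``
# and final answers as ``~~~final ... ~~~``; the code that parses that markup lives
# outside this module, so the exact payload format is assumed rather than shown here.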
def _resolve_torch_dtype(dtype_name: Optional[str]) -> Optional[torch.dtype]:
"""Map configuration dtype strings to torch dtypes."""
if dtype_name is None:
return None
normalised = dtype_name.strip().lower()
mapping = {
"float16": torch.float16,
"fp16": torch.float16,
"half": torch.float16,
"bfloat16": torch.bfloat16,
"bf16": torch.bfloat16,
"float32": torch.float32,
"fp32": torch.float32,
}
try:
return mapping[normalised]
except KeyError as exc:
raise ValueError(f"Unsupported torch dtype: {dtype_name}") from exc
def _load_yaml(path: Path) -> Dict[str, Any]:
"""Load and parse YAML configuration from disk."""
with path.open("r", encoding="utf-8") as handle:
return yaml.safe_load(handle)
__all__ = [
"GenerationSettings",
"Message",
"ModelConfig",
"NexaSciModelClient",
]
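if __name__ == "__main__":
    # Minimal smoke test: a sketch assuming an ``agent/config.yaml`` resolvable from
    # the current working directory and enough memory for the configured model.
    client = NexaSciModelClient(lazy_load=True)
    print("Available tools:", client.available_tools)
    reply = client.generate(
        [Message(role="user", content="Briefly introduce yourself.")],
        max_new_tokens=64,
    )
    print(reply)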