| """Client utilities for interacting with the NexaSci Assistant LLM.""" | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, Dict, Iterable, List, Optional, Sequence | |
| import torch | |
| import yaml | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig | |
@dataclass
class Message:
    """Represents a single conversational turn."""

    role: str
    content: str


@dataclass
class ModelConfig:
    """Configuration for loading the NexaSci model."""

    base_repo: str
    merged_path: Optional[str]
    adapter_path: Optional[str]
    backend: str
    torch_dtype: Optional[str]
    trust_remote_code: bool


@dataclass
class GenerationSettings:
    """Settings governing text generation."""

    max_new_tokens: int
    temperature: float
    top_p: float
    repetition_penalty: float
    tool_prefix: str
    tool_suffix: str
    final_prefix: str
    final_suffix: str
    stop_sequences: Sequence[str]
    system_prompt: str


class NexaSciModelClient:
    """High-level client for loading and querying the NexaSci Assistant."""

    def __init__(self, config_path: Path | str = "agent/config.yaml", lazy_load: bool = False) -> None:
        """Initialise the client by loading configuration and model weights.

        Parameters
        ----------
        config_path:
            Path to the agent configuration YAML file.
        lazy_load:
            If True, delay model loading until the first generation call.
        """
        self._config_path = Path(config_path)
        if not self._config_path.exists():
            raise FileNotFoundError(f"Configuration file not found at {self._config_path}")

        raw_cfg = _load_yaml(self._config_path)
        self.model_config = _parse_model_config(raw_cfg["model"])
        self.generation_settings = _parse_generation_settings(raw_cfg["generation"])
        self._tooling_config = raw_cfg.get("tooling", {})

        self._tokenizer: Any | None = None
        self._model: AutoModelForCausalLM | None = None
        self._lazy_load = lazy_load

        if not lazy_load:
            print("Loading tokenizer and model...")
            self._tokenizer = self._load_tokenizer()
            self._model = self._load_model()
            print("✓ Model loaded")
        else:
            print("Model will be loaded on first generation call")

    @property
    def tokenizer(self) -> Any:
        """Lazy-load the tokenizer if needed."""
        if self._tokenizer is None:
            self._tokenizer = self._load_tokenizer()
        return self._tokenizer

    @property
    def model(self) -> AutoModelForCausalLM:
        """Lazy-load the model if needed."""
        if self._model is None:
            print("Loading model (this may take 30-60 seconds)...")
            self._model = self._load_model()
            print("✓ Model loaded")
        return self._model

    def available_tools(self) -> Sequence[str]:
        """Return the list of tool identifiers declared in configuration."""
        return tuple(self._tooling_config.get("available_tools", []))

    def _resolve_model_path(self, path: Optional[str]) -> str:
        """Resolve a model path, treating relative paths as relative to the config file."""
        if path is None:
            return self.model_config.base_repo

        path_obj = Path(path)
        if path_obj.is_absolute():
            return str(path_obj)

        # Resolve relative to the config file's parent directory (project root)
        config_dir = self._config_path.parent
        resolved = (config_dir / path_obj).resolve()
        return str(resolved)

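    # Resolution sketch (paths illustrative, not shipped defaults): with the config
    # at agent/config.yaml and merged_path set to "models/merged", weights are read
    # from agent/models/merged; an absolute merged_path is used verbatim, and a
    # missing merged_path falls back to base_repo.
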
    def _load_tokenizer(self) -> Any:
        """Load the tokenizer for the configured model."""
        source = self._resolve_model_path(self.model_config.merged_path)
        print(f" Loading tokenizer from: {source}")

        tokenizer = AutoTokenizer.from_pretrained(
            source,
            trust_remote_code=self.model_config.trust_remote_code,
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        print(" ✓ Tokenizer loaded")
        return tokenizer

    def _load_model(self) -> AutoModelForCausalLM:
        """Load the base or merged model for inference."""
        source = self._resolve_model_path(self.model_config.merged_path)
        torch_dtype = _resolve_torch_dtype(self.model_config.torch_dtype)
        print(f" Loading model from: {source}")
        print(f" Using dtype: {torch_dtype}")

        # Check CUDA availability
        if torch.cuda.is_available():
            print(f" ✓ CUDA available: {torch.cuda.get_device_name(0)}")
            print(f" ✓ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / (1024**3):.1f} GB")
            device_map = "auto"
        else:
            print(" ⚠️ WARNING: CUDA not available! Model will load on CPU (very slow)")
            print(" Check: 1) NVIDIA drivers installed, 2) PyTorch with CUDA support, 3) GPU visible")
            device_map = None

        print(" This may take 30-60 seconds...")
        model = AutoModelForCausalLM.from_pretrained(
            source,
            device_map=device_map,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=True,
            trust_remote_code=self.model_config.trust_remote_code,
        )
        model.eval()

        # Verify device
        if torch.cuda.is_available():
            model_device = next(model.parameters()).device
            if model_device.type == "cuda":
                print(f" ✓ Model loaded on GPU: {model_device}")
            else:
                print(f" ⚠️ WARNING: Model loaded on {model_device}, not GPU!")
        else:
            print(" ⚠️ Model loaded on CPU (will be very slow)")

        return model

    def build_chat_messages(self, messages: Iterable[Message]) -> List[Dict[str, str]]:
        """Convert internal message objects to the tokenizer chat format."""
        # Materialise the iterable so it can be scanned twice (system check, then the loop).
        message_list = list(messages)
        formatted: List[Dict[str, str]] = []
        system_present = any(message.role == "system" for message in message_list)
        if not system_present:
            formatted.append(
                {
                    "role": "system",
                    "content": self.generation_settings.system_prompt,
                }
            )
        for message in message_list:
            formatted.append({"role": message.role, "content": message.content})
        return formatted

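    # For a history without a system turn, the returned list has the shape
    # [{"role": "system", "content": <configured system prompt>},
    #  {"role": "user", "content": "..."}, ...],
    # which is the format tokenizer.apply_chat_template consumes in generate() below.
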
    def _format_prompt(self, messages: Sequence[Message]) -> str:
        """Format messages into a prompt string for Falcon models."""
        parts = []
        system_present = any(msg.role == "system" for msg in messages)
        if not system_present:
            parts.append(f"System: {self.generation_settings.system_prompt}")

        for message in messages:
            role = message.role.capitalize()
            if role == "System":
                parts.append(f"System: {message.content}")
            elif role == "User":
                parts.append(f"User: {message.content}")
            elif role == "Assistant":
                parts.append(f"Assistant: {message.content}")
            elif role == "Tool":
                parts.append(f"Tool: {message.content}")

        parts.append("Assistant:")
        return "\n\n".join(parts)

    def generate(
        self,
        messages: Sequence[Message],
        *,
        max_new_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> str:
        """Generate a response from the model given a message history."""
        # Check if the tokenizer has a chat template, otherwise format manually
        if hasattr(self.tokenizer, "chat_template") and self.tokenizer.chat_template is not None:
            chat_messages = self.build_chat_messages(messages)
            inputs = self.tokenizer.apply_chat_template(
                chat_messages,
                tokenize=True,
                return_tensors="pt",
                add_generation_prompt=True,
            )
            # apply_chat_template may return the tensor directly
            if isinstance(inputs, torch.Tensor):
                input_ids = inputs.to(self.model.device)
            else:
                input_ids = inputs["input_ids"].to(self.model.device)
        else:
            # Manual formatting for models without chat templates (e.g., Falcon)
            prompt_text = self._format_prompt(messages)
            tokenized = self.tokenizer(
                prompt_text,
                return_tensors="pt",
                add_special_tokens=True,
            )
            input_ids = tokenized["input_ids"].to(self.model.device)

        # Fall back to configured defaults only when no explicit override is given
        # (an explicit 0.0 should not be silently replaced).
        temp = temperature if temperature is not None else self.generation_settings.temperature
        top_p_val = top_p if top_p is not None else self.generation_settings.top_p
        # Enable sampling when temperature/top_p are used
        do_sample = temp > 0.0 or top_p_val < 1.0

        generation_config = GenerationConfig(
            max_new_tokens=max_new_tokens or self.generation_settings.max_new_tokens,
            temperature=temp,
            top_p=top_p_val,
            do_sample=do_sample,
            repetition_penalty=self.generation_settings.repetition_penalty,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        # Ensure model is loaded
        model = self.model
        model_device = next(model.parameters()).device
        if model_device.type == "cuda":
            torch.cuda.empty_cache()
            free_mem = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated(0)
            free_gb = free_mem / (1024**3)
            print(f" 💾 GPU memory: {free_gb:.1f} GB free")
            if free_gb < 5:
                print(f" ⚠️ Warning: Low GPU memory ({free_gb:.1f} GB)")
        else:
            print(f" ⚠️ WARNING: Model is on {model_device}, not GPU! Generation will be very slow.")
            print(" This is likely why it's freezing. Check CUDA installation.")

        try:
            print(f" 🚀 Starting generation (max {generation_config.max_new_tokens} tokens)...")
            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids=input_ids,
                    generation_config=generation_config,
                    do_sample=do_sample,
                )
            print(" ✓ Generation complete")
        except torch.cuda.OutOfMemoryError as e:
            torch.cuda.empty_cache()
            raise RuntimeError(
                f"GPU out of memory. Try: 1) Reduce max_new_tokens, 2) Use CPU, "
                f"3) Close other GPU processes. Original error: {e}"
            ) from e
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                torch.cuda.empty_cache()
                raise RuntimeError(
                    "GPU out of memory. Try reducing max_new_tokens or closing other processes."
                ) from e
            raise
        except Exception as e:
            raise RuntimeError(f"Generation failed: {e}") from e

        generated_ids = output_ids[0, input_ids.shape[-1] :]
        generated_text = self.tokenizer.decode(
            generated_ids,
            skip_special_tokens=True,
        )
        return generated_text.strip()


def _parse_model_config(raw_config: Dict[str, Any]) -> ModelConfig:
    """Validate and coerce the raw model configuration."""
    return ModelConfig(
        base_repo=raw_config["base_repo"],
        merged_path=raw_config.get("merged_path"),
        adapter_path=raw_config.get("adapter_path"),
        backend=raw_config.get("backend", "transformers"),
        torch_dtype=raw_config.get("torch_dtype"),
        trust_remote_code=bool(raw_config.get("trust_remote_code", True)),
    )


def _parse_generation_settings(raw_config: Dict[str, Any]) -> GenerationSettings:
    """Validate and coerce generation settings."""
    return GenerationSettings(
        max_new_tokens=int(raw_config.get("max_new_tokens", 512)),
        temperature=float(raw_config.get("temperature", 0.3)),
        top_p=float(raw_config.get("top_p", 0.9)),
        repetition_penalty=float(raw_config.get("repetition_penalty", 1.05)),
        tool_prefix=str(raw_config.get("tool_prefix", "~~~toolcall")),
        tool_suffix=str(raw_config.get("tool_suffix", "~~~")),
        final_prefix=str(raw_config.get("final_prefix", "~~~final")),
        final_suffix=str(raw_config.get("final_suffix", "~~~")),
        stop_sequences=tuple(raw_config.get("stop_sequences", [])),
        system_prompt=str(
            raw_config.get(
                "system_prompt",
                "You are the NexaSci Assistant, a scientific research agent.",
            )
        ),
    )


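# For reference, a minimal sketch of the configuration structure these parsers
# accept, written as the dictionary `_load_yaml` would return from config.yaml.
# The repo id, paths, and tool name are illustrative placeholders, not values
# shipped with the project.
_EXAMPLE_RAW_CONFIG: Dict[str, Any] = {
    "model": {
        "base_repo": "tiiuae/falcon-7b-instruct",  # hypothetical base checkpoint
        "merged_path": "models/merged",  # resolved relative to the config file
        "adapter_path": None,
        "backend": "transformers",
        "torch_dtype": "bfloat16",
        "trust_remote_code": True,
    },
    "generation": {
        "max_new_tokens": 512,
        "temperature": 0.3,
        "top_p": 0.9,
        "repetition_penalty": 1.05,
        "stop_sequences": [],
    },
    "tooling": {"available_tools": ["literature_search"]},
}
# e.g. _parse_model_config(_EXAMPLE_RAW_CONFIG["model"]) yields a ModelConfig whose
# merged_path is still resolved later via NexaSciModelClient._resolve_model_path.

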
def _resolve_torch_dtype(dtype_name: Optional[str]) -> Optional[torch.dtype]:
    """Map configuration dtype strings to torch dtypes."""
    if dtype_name is None:
        return None

    normalised = dtype_name.strip().lower()
    mapping = {
        "float16": torch.float16,
        "fp16": torch.float16,
        "half": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float32": torch.float32,
        "fp32": torch.float32,
    }
    try:
        return mapping[normalised]
    except KeyError as exc:
        raise ValueError(f"Unsupported torch dtype: {dtype_name}") from exc


def _load_yaml(path: Path) -> Dict[str, Any]:
    """Load and parse YAML configuration from disk."""
    with path.open("r", encoding="utf-8") as handle:
        return yaml.safe_load(handle)


__all__ = [
    "GenerationSettings",
    "Message",
    "ModelConfig",
    "NexaSciModelClient",
]
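

# Minimal usage sketch. It assumes agent/config.yaml exists (see the example
# structure above) and that the configured weights can be found locally or
# downloaded; the question text below is purely illustrative.
if __name__ == "__main__":
    client = NexaSciModelClient(lazy_load=True)
    history = [Message(role="user", content="Summarise what the NexaSci Assistant can do.")]
    reply = client.generate(history, max_new_tokens=128, temperature=0.2)
    print(reply)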