"""
HF Inference API Client for NeuroSAM3.

Wrapper for calling large LLMs (Gemma, Kimi, GLM) via HuggingFace Inference API.
These models are too large to load locally on HF Spaces GPU.
"""

from typing import Optional, List, Dict, Any
import json
import base64
from io import BytesIO

from PIL import Image

from logger_config import logger
from config import (
    HF_TOKEN,
    LLM_MODELS,
    DEFAULT_LLM_PROVIDER,
    INFERENCE_API_TIMEOUT,
)

try:
    from huggingface_hub import InferenceClient
    HF_INFERENCE_AVAILABLE = True
except ImportError:
    HF_INFERENCE_AVAILABLE = False
    logger.warning("huggingface_hub InferenceClient not available")


class HFInferenceAPI:
    """Wrapper for HF Inference API calls to large LLMs.

    Holds one ``InferenceClient`` (or ``None`` when ``huggingface_hub`` is
    missing, ``HF_TOKEN`` is unset, or client construction fails) and sends
    every request to the model that ``LLM_MODELS`` maps the current provider
    to. The request methods do not raise on API errors: they log the error
    and return ``None``.
    """

    def __init__(self, provider: str = DEFAULT_LLM_PROVIDER):
        """Create the wrapper and, if possible, the underlying client.

        Args:
            provider: Key into ``LLM_MODELS``. An unknown key is still stored
                (for backward compatibility) but a warning is logged, and
                ``model_id`` falls back to the default provider's model.
        """
        self.provider = provider
        self._client: Optional[Any] = None

        # Mirror set_provider's validation so a bad constructor argument is
        # visible in the logs instead of silently using the default model.
        if provider not in LLM_MODELS:
            logger.warning(
                f"Unknown provider: {provider}. Falling back to the model of "
                f"'{DEFAULT_LLM_PROVIDER}'. Available: {list(LLM_MODELS.keys())}"
            )

        if not HF_INFERENCE_AVAILABLE:
            logger.error("huggingface_hub not installed with inference support")
            return

        if not HF_TOKEN:
            logger.warning("HF_TOKEN not set — Inference API will not work")
            return

        try:
            self._client = InferenceClient(token=HF_TOKEN, timeout=INFERENCE_API_TIMEOUT)
            logger.info(f"HF Inference API client initialized (provider: {provider})")
        except Exception as e:
            logger.error(f"Failed to initialize InferenceClient: {e}")

    @property
    def is_available(self) -> bool:
        """Check if the inference client is ready."""
        return self._client is not None

    @property
    def model_id(self) -> str:
        """Model ID for the current provider (default provider's model if unknown)."""
        # Look up the fallback lazily: the original `.get(p, LLM_MODELS[DEFAULT])`
        # indexed the default entry even when `p` was valid.
        if self.provider in LLM_MODELS:
            return LLM_MODELS[self.provider]
        return LLM_MODELS[DEFAULT_LLM_PROVIDER]

    def set_provider(self, provider: str):
        """Switch LLM provider; unknown providers are logged and ignored."""
        if provider in LLM_MODELS:
            self.provider = provider
            logger.info(f"Switched LLM provider to: {provider} ({self.model_id})")
        else:
            logger.warning(f"Unknown provider: {provider}. Available: {list(LLM_MODELS.keys())}")

    @staticmethod
    def _build_messages(
        messages: List[Dict[str, Any]],
        system_prompt: Optional[str],
    ) -> List[Dict[str, Any]]:
        """Return a new list: an optional system message followed by ``messages``.

        The caller's list is not mutated.
        """
        full_messages: List[Dict[str, Any]] = []
        if system_prompt:
            full_messages.append({"role": "system", "content": system_prompt})
        full_messages.extend(messages)
        return full_messages

    @staticmethod
    def _encode_image_png(image: Image.Image) -> str:
        """Return ``image`` encoded as a base64 PNG string.

        Pillow raises ``OSError`` when asked to write a mode PNG cannot store
        (e.g. ``"F"`` float images or ``"CMYK"``); in that case an RGB copy
        is encoded instead. That conversion does not rescale intensities, so
        high-dynamic-range images should be windowed/normalized by the caller.
        """
        buffered = BytesIO()
        try:
            image.save(buffered, format="PNG")
        except OSError:
            # Start from a fresh buffer in case the failed save wrote a prefix.
            buffered = BytesIO()
            image.convert("RGB").save(buffered, format="PNG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    def _parse_tool_calls(self, tool_calls: List[Any]) -> Optional[List[Dict[str, Any]]]:
        """Convert SDK tool-call objects into plain dicts.

        Each result has ``id`` (``None`` if the SDK object has none), ``name``
        and ``arguments``. String arguments are JSON-decoded; empty/``None``
        arguments become ``{}``. A call whose arguments are not valid JSON is
        skipped with a warning rather than discarding the whole response.

        Returns:
            The parsed list, or ``None`` if no call could be parsed.
        """
        parsed: List[Dict[str, Any]] = []
        for tc in tool_calls:
            name = tc.function.name
            raw_args = tc.function.arguments
            if raw_args is None or (isinstance(raw_args, str) and not raw_args.strip()):
                arguments: Any = {}
            elif isinstance(raw_args, str):
                try:
                    arguments = json.loads(raw_args)
                except json.JSONDecodeError as e:
                    logger.warning(
                        f"Skipping tool call '{name}' with malformed JSON arguments "
                        f"({self.provider}): {e}"
                    )
                    continue
            else:
                # Some backends already return decoded arguments.
                arguments = raw_args
            parsed.append({
                "id": getattr(tc, "id", None),
                "name": name,
                "arguments": arguments,
            })
        return parsed or None

    def chat(
        self,
        messages: List[Dict[str, str]],
        system_prompt: Optional[str] = None,
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> Optional[str]:
        """
        Send a chat completion request to the LLM.

        Args:
            messages: List of {"role": "user"/"assistant", "content": "..."}
            system_prompt: Optional system message prepended
            max_tokens: Maximum response tokens
            temperature: Sampling temperature (lower = more deterministic)

        Returns:
            Generated text response, or None on failure
        """
        if not self.is_available:
            logger.error("Inference API not available")
            return None

        try:
            response = self._client.chat_completion(
                model=self.model_id,
                messages=self._build_messages(messages, system_prompt),
                max_tokens=max_tokens,
                temperature=temperature,
            )
            if response and response.choices:
                return response.choices[0].message.content
            return None
        except Exception as e:
            logger.error(f"Inference API chat error ({self.provider}): {e}")
            return None

    def chat_with_tools(
        self,
        messages: List[Dict[str, str]],
        tools: List[Dict[str, Any]],
        system_prompt: Optional[str] = None,
        max_tokens: int = 2048,
    ) -> Optional[Dict[str, Any]]:
        """
        Send a chat request with tool-use (function calling).

        Args:
            messages: Chat messages
            tools: List of tool definitions (OpenAI-compatible format)
            system_prompt: System message
            max_tokens: Max response tokens

        Returns:
            Dict with 'content' (text or None) and 'tool_calls' (list of
            {'id', 'name', 'arguments'} dicts, or None), or None on failure.
            Tool calls with malformed JSON arguments are skipped and logged.
        """
        if not self.is_available:
            logger.error("Inference API not available")
            return None

        try:
            response = self._client.chat_completion(
                model=self.model_id,
                messages=self._build_messages(messages, system_prompt),
                tools=tools,
                max_tokens=max_tokens,
                temperature=0.1,  # Low temp for tool-calling precision
            )

            if not response or not response.choices:
                return None

            message = response.choices[0].message
            result: Dict[str, Any] = {"content": None, "tool_calls": None}

            if message.content:
                result["content"] = message.content

            if message.tool_calls:
                result["tool_calls"] = self._parse_tool_calls(message.tool_calls)

            return result

        except Exception as e:
            logger.error(f"Inference API tool-call error ({self.provider}): {e}")
            return None

    def vision_chat(
        self,
        text: str,
        image: Image.Image,
        system_prompt: Optional[str] = None,
        max_tokens: int = 2048,
    ) -> Optional[str]:
        """
        Send a vision+text request (for multimodal models like Gemma, GLM-4.1V).

        The image is sent inline as a base64 PNG data URL, followed by the
        text prompt, in a single user message.

        Args:
            text: Text prompt
            image: PIL Image to analyze (modes PNG cannot store are sent as RGB)
            system_prompt: Optional system message
            max_tokens: Max response tokens

        Returns:
            Generated text response, or None on failure
        """
        if not self.is_available:
            logger.error("Inference API not available")
            return None

        try:
            img_b64 = self._encode_image_png(image)
            user_message = {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}},
                    {"type": "text", "text": text},
                ],
            }

            response = self._client.chat_completion(
                model=self.model_id,
                messages=self._build_messages([user_message], system_prompt),
                max_tokens=max_tokens,
            )

            if response and response.choices:
                return response.choices[0].message.content
            return None

        except Exception as e:
            logger.error(f"Inference API vision error ({self.provider}): {e}")
            return None


# Tool definitions for the agentic orchestrator, in the OpenAI-compatible
# function-calling format that `HFInferenceAPI.chat_with_tools` accepts as `tools`.
NEUROIMAGING_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "segment_with_sam3",
            "description": (
                "Segment anatomical structures using SAM3 with text prompts. "
                "Best for general structures (brain, skull, ventricles, eyes)."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "prompt": {
                        "type": "string",
                        "description": "Anatomical structure to segment (e.g., 'brain', 'tumor', 'ventricles')",
                    },
                    "modality": {
                        "type": "string",
                        "enum": ["CT", "MRI"],
                        "description": "Imaging modality",
                    },
                    "window_type": {
                        "type": "string",
                        "enum": ["Brain (Grey Matter)", "Bone (Skull)", "Default"],
                        "default": "Default",
                    },
                },
                "required": ["prompt", "modality"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "segment_with_medsam",
            "description": (
                "Medical-optimized segmentation using MedSAM. "
                "Better for subtle pathology, tumors, lesions."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "bounding_box": {
                        "type": "array",
                        "items": {"type": "integer"},
                        "description": "[x1, y1, x2, y2] bounding box for region of interest",
                    },
                    "modality": {
                        "type": "string",
                        "enum": ["CT", "MRI"],
                    },
                },
                "required": ["bounding_box", "modality"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "classify_image",
            "description": (
                "Zero-shot medical image classification using BiomedCLIP. "
                "Identifies modality, body region, and pathology."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "candidate_labels": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Labels to classify against",
                    },
                },
                "required": [],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "measure_roi",
            "description": (
                "Calculate ROI statistics from segmentation. "
                "Returns area, intensity stats, centroid, bounding box."
            ),
            "parameters": {
                "type": "object",
                "properties": {
                    "prompt": {
                        "type": "string",
                        "description": "What to segment and measure",
                    },
                    "modality": {
                        "type": "string",
                        "enum": ["CT", "MRI"],
                    },
                },
                "required": ["prompt", "modality"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "generate_report",
            "description": "Generate structured clinical or research report from findings.",
            "parameters": {
                "type": "object",
                "properties": {
                    "findings_summary": {
                        "type": "string",
                        "description": "Summary of segmentation findings and measurements",
                    },
                    "report_style": {
                        "type": "string",
                        "enum": ["radiology", "neurosurgery", "research"],
                        "default": "radiology",
                    },
                    "clinical_context": {
                        "type": "string",
                        "description": "Optional patient history or clinical indication",
                    },
                },
                "required": ["findings_summary"],
            },
        },
    },
]


# Global singleton (constructed at import time with DEFAULT_LLM_PROVIDER).
inference_client = HFInferenceAPI()