# NeuroSAM3 / inference_api.py
# Last commit (mmrech): feat: transform NeuroSAM3 into agentic neuroimaging platform
# Revision a7e0222 (unverified), 11.8 kB
"""
HF Inference API Client for NeuroSAM3.
Wrapper for calling large LLMs (Gemma, Kimi, GLM) via HuggingFace Inference API.
These models are too large to load locally on HF Spaces GPU.
"""
from typing import Optional, List, Dict, Any
import json
import base64
from io import BytesIO
from PIL import Image
from logger_config import logger
from config import (
HF_TOKEN,
LLM_MODELS,
DEFAULT_LLM_PROVIDER,
INFERENCE_API_TIMEOUT,
)
# huggingface_hub is an optional dependency: record whether InferenceClient
# could be imported so the wrapper can degrade gracefully instead of failing
# at import time.
try:
    from huggingface_hub import InferenceClient
except ImportError:
    HF_INFERENCE_AVAILABLE = False
    logger.warning("huggingface_hub InferenceClient not available")
else:
    HF_INFERENCE_AVAILABLE = True
class HFInferenceAPI:
    """Wrapper for HF Inference API calls to large LLMs.

    The request methods (``chat``, ``chat_with_tools``, ``vision_chat``) are
    best-effort: when the client could not be created or a request fails, the
    problem is logged and ``None`` is returned instead of raising.
    """

    def __init__(self, provider: str = DEFAULT_LLM_PROVIDER):
        """Create the underlying ``InferenceClient`` when possible.

        Construction never raises: if ``huggingface_hub`` is missing, no
        ``HF_TOKEN`` is configured, or the client constructor fails, the
        problem is logged and the instance is left with ``is_available``
        returning ``False``.

        Args:
            provider: Key into ``LLM_MODELS`` selecting the model to call.
                It is not validated here; an unknown key makes ``model_id``
                fall back to the default provider's model.
        """
        self.provider = provider
        self._client: Optional[Any] = None

        if not HF_INFERENCE_AVAILABLE:
            logger.error("huggingface_hub not installed with inference support")
            return
        if not HF_TOKEN:
            logger.warning("HF_TOKEN not set — Inference API will not work")
            return

        try:
            self._client = InferenceClient(token=HF_TOKEN, timeout=INFERENCE_API_TIMEOUT)
            logger.info(f"HF Inference API client initialized (provider: {provider})")
        except Exception as e:
            logger.error(f"Failed to initialize InferenceClient: {e}")

    @property
    def is_available(self) -> bool:
        """Check if the inference client is ready."""
        return self._client is not None

    @property
    def model_id(self) -> str:
        """Get the model ID for the current provider.

        Falls back to the default provider's model when the current provider
        is not a key of ``LLM_MODELS``.
        """
        if self.provider in LLM_MODELS:
            return LLM_MODELS[self.provider]
        # Look the fallback up only when it is needed, so a valid provider
        # never fails because of a misconfigured default.
        return LLM_MODELS[DEFAULT_LLM_PROVIDER]

    def set_provider(self, provider: str):
        """Switch LLM provider.

        Unknown providers are rejected with a warning and the current
        provider is left unchanged.

        Args:
            provider: Key into ``LLM_MODELS``.
        """
        if provider in LLM_MODELS:
            self.provider = provider
            logger.info(f"Switched LLM provider to: {provider} ({self.model_id})")
        else:
            logger.warning(f"Unknown provider: {provider}. Available: {list(LLM_MODELS.keys())}")

    @staticmethod
    def _build_messages(
        messages: List[Dict[str, Any]],
        system_prompt: Optional[str],
    ) -> List[Dict[str, Any]]:
        """Return a new message list with the optional system prompt first."""
        full_messages: List[Dict[str, Any]] = []
        if system_prompt:
            full_messages.append({"role": "system", "content": system_prompt})
        full_messages.extend(messages)
        return full_messages

    @staticmethod
    def _parse_tool_arguments(raw: Any) -> Any:
        """Decode the arguments of one tool call.

        JSON strings are decoded; missing or blank arguments become an empty
        dict so a call to a tool without parameters does not abort parsing.
        Values that are not strings are passed through unchanged. Malformed
        JSON still raises and is handled by the caller.
        """
        if raw is None:
            return {}
        if isinstance(raw, str):
            if not raw.strip():
                return {}
            return json.loads(raw)
        return raw

    def chat(
        self,
        messages: List[Dict[str, str]],
        system_prompt: Optional[str] = None,
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> Optional[str]:
        """
        Send a chat completion request to the LLM.

        Args:
            messages: List of {"role": "user"/"assistant", "content": "..."}
            system_prompt: Optional system message prepended
            max_tokens: Maximum response tokens
            temperature: Sampling temperature (lower = more deterministic)

        Returns:
            Generated text response, or None on failure
        """
        if not self.is_available:
            logger.error("Inference API not available")
            return None

        try:
            response = self._client.chat_completion(
                model=self.model_id,
                messages=self._build_messages(messages, system_prompt),
                max_tokens=max_tokens,
                temperature=temperature,
            )
            if response and response.choices:
                return response.choices[0].message.content
            return None
        except Exception as e:
            logger.error(f"Inference API chat error ({self.provider}): {e}")
            return None

    def chat_with_tools(
        self,
        messages: List[Dict[str, str]],
        tools: List[Dict[str, Any]],
        system_prompt: Optional[str] = None,
        max_tokens: int = 2048,
        temperature: float = 0.1,
    ) -> Optional[Dict[str, Any]]:
        """
        Send a chat request with tool-use (function calling).

        Args:
            messages: Chat messages
            tools: List of tool definitions (OpenAI-compatible format)
            system_prompt: System message
            max_tokens: Max response tokens
            temperature: Sampling temperature; defaults to a low value for
                tool-calling precision

        Returns:
            Dict with 'content' (text or None) and 'tool_calls' (list of
            {"name", "arguments"} dicts or None), or None on failure
        """
        if not self.is_available:
            logger.error("Inference API not available")
            return None

        try:
            response = self._client.chat_completion(
                model=self.model_id,
                messages=self._build_messages(messages, system_prompt),
                tools=tools,
                max_tokens=max_tokens,
                temperature=temperature,
            )
            if not response or not response.choices:
                return None

            message = response.choices[0].message
            result: Dict[str, Any] = {"content": None, "tool_calls": None}
            if message.content:
                result["content"] = message.content
            if message.tool_calls:
                result["tool_calls"] = [
                    {
                        "name": tc.function.name,
                        "arguments": self._parse_tool_arguments(tc.function.arguments),
                    }
                    for tc in message.tool_calls
                ]
            return result
        except Exception as e:
            logger.error(f"Inference API tool-call error ({self.provider}): {e}")
            return None

    def vision_chat(
        self,
        text: str,
        image: Image.Image,
        system_prompt: Optional[str] = None,
        max_tokens: int = 2048,
        temperature: Optional[float] = None,
    ) -> Optional[str]:
        """
        Send a vision+text request (for multimodal models like Gemma, GLM-4.1V).

        Args:
            text: Text prompt
            image: PIL Image to analyze
            system_prompt: Optional system message
            max_tokens: Max response tokens
            temperature: Optional sampling temperature; when None the
                parameter is not sent with the request

        Returns:
            Generated text response, or None on failure
        """
        if not self.is_available:
            logger.error("Inference API not available")
            return None

        try:
            # Encode the image as a base64 PNG data URL.
            # NOTE(review): Pillow cannot write every image mode as PNG
            # (e.g. float "F" images — confirm); such a failure is caught
            # below and reported as None.
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

            messages: List[Dict[str, Any]] = []
            if system_prompt:
                messages.append({"role": "system", "content": system_prompt})
            messages.append({
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}},
                    {"type": "text", "text": text},
                ],
            })

            request_kwargs: Dict[str, Any] = {
                "model": self.model_id,
                "messages": messages,
                "max_tokens": max_tokens,
            }
            # Only send a temperature when the caller asked for one, so the
            # default request is unchanged.
            if temperature is not None:
                request_kwargs["temperature"] = temperature

            response = self._client.chat_completion(**request_kwargs)
            if response and response.choices:
                return response.choices[0].message.content
            return None
        except Exception as e:
            logger.error(f"Inference API vision error ({self.provider}): {e}")
            return None
# Tool definitions for the agentic orchestrator
def _function_tool(
    name: str,
    description: str,
    properties: Dict[str, Any],
    required: List[str],
) -> Dict[str, Any]:
    """Wrap a tool's name, description and parameter schema in the
    function-tool envelope shared by every entry of NEUROIMAGING_TOOLS."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


# Each entry describes one callable tool: its name, a description, and a
# JSON-schema object listing its parameters and which of them are required.
NEUROIMAGING_TOOLS = [
    _function_tool(
        name="segment_with_sam3",
        description="Segment anatomical structures using SAM3 with text prompts. Best for general structures (brain, skull, ventricles, eyes).",
        properties={
            "prompt": {
                "type": "string",
                "description": "Anatomical structure to segment (e.g., 'brain', 'tumor', 'ventricles')",
            },
            "modality": {
                "type": "string",
                "enum": ["CT", "MRI"],
                "description": "Imaging modality",
            },
            "window_type": {
                "type": "string",
                "enum": ["Brain (Grey Matter)", "Bone (Skull)", "Default"],
                "default": "Default",
            },
        },
        required=["prompt", "modality"],
    ),
    _function_tool(
        name="segment_with_medsam",
        description="Medical-optimized segmentation using MedSAM. Better for subtle pathology, tumors, lesions.",
        properties={
            "bounding_box": {
                "type": "array",
                "items": {"type": "integer"},
                "description": "[x1, y1, x2, y2] bounding box for region of interest",
            },
            "modality": {
                "type": "string",
                "enum": ["CT", "MRI"],
            },
        },
        required=["bounding_box", "modality"],
    ),
    _function_tool(
        name="classify_image",
        description="Zero-shot medical image classification using BiomedCLIP. Identifies modality, body region, and pathology.",
        properties={
            "candidate_labels": {
                "type": "array",
                "items": {"type": "string"},
                "description": "Labels to classify against",
            },
        },
        required=[],
    ),
    _function_tool(
        name="measure_roi",
        description="Calculate ROI statistics from segmentation. Returns area, intensity stats, centroid, bounding box.",
        properties={
            "prompt": {
                "type": "string",
                "description": "What to segment and measure",
            },
            "modality": {
                "type": "string",
                "enum": ["CT", "MRI"],
            },
        },
        required=["prompt", "modality"],
    ),
    _function_tool(
        name="generate_report",
        description="Generate structured clinical or research report from findings.",
        properties={
            "findings_summary": {
                "type": "string",
                "description": "Summary of segmentation findings and measurements",
            },
            "report_style": {
                "type": "string",
                "enum": ["radiology", "neurosurgery", "research"],
                "default": "radiology",
            },
            "clinical_context": {
                "type": "string",
                "description": "Optional patient history or clinical indication",
            },
        },
        required=["findings_summary"],
    ),
]
# Global singleton: a shared client built with the default provider when this
# module is imported. Construction logs instead of raising when the client
# cannot be created, so check `inference_client.is_available` before use.
inference_client = HFInferenceAPI()