| """ |
| HF Inference API Client for NeuroSAM3. |
| |
| Wrapper for calling large LLMs (Gemma, Kimi, GLM) via HuggingFace Inference API. |
| These models are too large to load locally on HF Spaces GPU. |
| """ |
|
|
| from typing import Optional, List, Dict, Any |
| import json |
| import base64 |
| from io import BytesIO |
| from PIL import Image |
| from logger_config import logger |
| from config import ( |
| HF_TOKEN, |
| LLM_MODELS, |
| DEFAULT_LLM_PROVIDER, |
| INFERENCE_API_TIMEOUT, |
| ) |
|
|
# Optional dependency: InferenceClient is only needed for remote LLM calls.
# HF_INFERENCE_AVAILABLE records whether the import succeeded so the wrapper
# class below can degrade gracefully instead of failing at import time.
try:
    from huggingface_hub import InferenceClient
except ImportError:
    HF_INFERENCE_AVAILABLE = False
    logger.warning("huggingface_hub InferenceClient not available")
else:
    HF_INFERENCE_AVAILABLE = True
|
|
|
|
class HFInferenceAPI:
    """Wrapper for HF Inference API calls to large LLMs.

    The target model is selected by ``provider`` (a key of ``LLM_MODELS``).
    All request methods are best-effort: when the client is unavailable or
    the remote call raises, they log the error and return ``None`` instead of
    propagating the exception.
    """

    def __init__(self, provider: str = DEFAULT_LLM_PROVIDER):
        """Store the provider and try to build the underlying InferenceClient.

        Args:
            provider: Key into ``LLM_MODELS``. An unknown key is kept as-is,
                but ``model_id`` then resolves to the default provider's model.

        The client stays ``None`` (so ``is_available`` is False) when
        huggingface_hub's InferenceClient could not be imported, when
        ``HF_TOKEN`` is empty, or when constructing the client raises.
        """
        self.provider = provider
        self._client: Optional[Any] = None

        if provider not in LLM_MODELS:
            # model_id would silently fall back to the default model; make that
            # visible, consistent with the warning set_provider() emits.
            logger.warning(
                f"Unknown provider: {provider}. Using {DEFAULT_LLM_PROVIDER}. "
                f"Available: {list(LLM_MODELS.keys())}"
            )

        if not HF_INFERENCE_AVAILABLE:
            logger.error("huggingface_hub not installed with inference support")
            return

        if not HF_TOKEN:
            logger.warning("HF_TOKEN not set — Inference API will not work")
            return

        try:
            self._client = InferenceClient(token=HF_TOKEN, timeout=INFERENCE_API_TIMEOUT)
            logger.info(f"HF Inference API client initialized (provider: {provider})")
        except Exception as e:
            logger.error(f"Failed to initialize InferenceClient: {e}")

    @property
    def is_available(self) -> bool:
        """Check if the inference client is ready."""
        return self._client is not None

    @property
    def model_id(self) -> str:
        """Get the current model ID (the default provider's model if unknown)."""
        if self.provider in LLM_MODELS:
            return LLM_MODELS[self.provider]
        # Resolve the fallback lazily so a valid provider never depends on
        # DEFAULT_LLM_PROVIDER also being present in LLM_MODELS.
        return LLM_MODELS[DEFAULT_LLM_PROVIDER]

    def set_provider(self, provider: str):
        """Switch LLM provider; unknown keys are rejected with a warning."""
        if provider in LLM_MODELS:
            self.provider = provider
            logger.info(f"Switched LLM provider to: {provider} ({self.model_id})")
        else:
            logger.warning(f"Unknown provider: {provider}. Available: {list(LLM_MODELS.keys())}")

    # ------------------------------------------------------------------ #
    # Private helpers shared by the request methods
    # ------------------------------------------------------------------ #

    @staticmethod
    def _build_messages(
        messages: List[Dict[str, Any]],
        system_prompt: Optional[str],
    ) -> List[Dict[str, Any]]:
        """Return a new message list with an optional system message first.

        The caller's ``messages`` list is not modified.
        """
        full_messages: List[Dict[str, Any]] = []
        if system_prompt:
            full_messages.append({"role": "system", "content": system_prompt})
        full_messages.extend(messages)
        return full_messages

    @staticmethod
    def _first_content(response: Any) -> Optional[str]:
        """Return the text of the first choice, or None if there is none."""
        if response and response.choices:
            return response.choices[0].message.content
        return None

    def _parse_tool_call(self, tool_call: Any) -> Optional[Dict[str, Any]]:
        """Convert one tool call into ``{"id", "name", "arguments"}``.

        ``arguments`` may arrive as a JSON string or already decoded; a
        missing/blank string is treated as "no arguments" (``{}``). Returns
        None (after logging) when the string is not valid JSON, so one bad
        call does not discard the rest of the response.
        """
        name = tool_call.function.name
        raw_args = tool_call.function.arguments

        if raw_args is None or (isinstance(raw_args, str) and not raw_args.strip()):
            arguments: Any = {}
        elif isinstance(raw_args, str):
            try:
                arguments = json.loads(raw_args)
            except json.JSONDecodeError as e:
                logger.warning(
                    f"Skipping tool call '{name}' with malformed arguments ({self.provider}): {e}"
                )
                return None
        else:
            arguments = raw_args

        return {
            # Needed to send the tool result back as a "tool" message.
            "id": getattr(tool_call, "id", None),
            "name": name,
            "arguments": arguments,
        }

    # ------------------------------------------------------------------ #
    # Public request methods
    # ------------------------------------------------------------------ #

    def chat(
        self,
        messages: List[Dict[str, str]],
        system_prompt: Optional[str] = None,
        max_tokens: int = 2048,
        temperature: float = 0.3,
    ) -> Optional[str]:
        """
        Send a chat completion request to the LLM.

        Args:
            messages: List of {"role": "user"/"assistant", "content": "..."}
            system_prompt: Optional system message prepended
            max_tokens: Maximum response tokens
            temperature: Sampling temperature (lower = more deterministic)

        Returns:
            Generated text response, or None on failure
        """
        if not self.is_available:
            logger.error("Inference API not available")
            return None

        try:
            response = self._client.chat_completion(
                model=self.model_id,
                messages=self._build_messages(messages, system_prompt),
                max_tokens=max_tokens,
                temperature=temperature,
            )
            return self._first_content(response)
        except Exception as e:
            logger.error(f"Inference API chat error ({self.provider}): {e}")
            return None

    def chat_with_tools(
        self,
        messages: List[Dict[str, str]],
        tools: List[Dict[str, Any]],
        system_prompt: Optional[str] = None,
        max_tokens: int = 2048,
        temperature: float = 0.1,
    ) -> Optional[Dict[str, Any]]:
        """
        Send a chat request with tool-use (function calling).

        Args:
            messages: Chat messages
            tools: List of tool definitions (OpenAI-compatible format)
            system_prompt: System message
            max_tokens: Max response tokens
            temperature: Sampling temperature (low by default so tool
                selection is stable)

        Returns:
            Dict with 'content' (text or None) and 'tool_calls' (list of
            {'id', 'name', 'arguments'} dicts, or None), or None on failure.
            Tool calls whose arguments are not valid JSON are skipped with
            a warning rather than failing the whole request.
        """
        if not self.is_available:
            logger.error("Inference API not available")
            return None

        try:
            response = self._client.chat_completion(
                model=self.model_id,
                messages=self._build_messages(messages, system_prompt),
                tools=tools,
                max_tokens=max_tokens,
                temperature=temperature,
            )

            if not response or not response.choices:
                return None

            message = response.choices[0].message
            result: Dict[str, Any] = {"content": None, "tool_calls": None}

            if message.content:
                result["content"] = message.content

            if message.tool_calls:
                parsed = [self._parse_tool_call(tc) for tc in message.tool_calls]
                result["tool_calls"] = [call for call in parsed if call is not None] or None

            return result

        except Exception as e:
            logger.error(f"Inference API tool-call error ({self.provider}): {e}")
            return None

    def vision_chat(
        self,
        text: str,
        image: Image.Image,
        system_prompt: Optional[str] = None,
        max_tokens: int = 2048,
        temperature: Optional[float] = None,
    ) -> Optional[str]:
        """
        Send a vision+text request (for multimodal models like Gemma, GLM-4.1V).

        The image is PNG-encoded and sent inline as a base64 data URL.

        Args:
            text: Text prompt
            image: PIL Image to analyze
            system_prompt: Optional system message
            max_tokens: Max response tokens
            temperature: Sampling temperature; None leaves it unset so the
                server default applies

        Returns:
            Generated text response, or None on failure
        """
        if not self.is_available:
            logger.error("Inference API not available")
            return None

        try:
            with BytesIO() as buffered:
                image.save(buffered, format="PNG")
                img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

            user_message = {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_b64}"}},
                    {"type": "text", "text": text},
                ],
            }

            response = self._client.chat_completion(
                model=self.model_id,
                messages=self._build_messages([user_message], system_prompt),
                max_tokens=max_tokens,
                temperature=temperature,
            )
            return self._first_content(response)

        except Exception as e:
            logger.error(f"Inference API vision error ({self.provider}): {e}")
            return None
|
|
|
|
| |
# Tool definitions in the OpenAI-compatible function-calling format accepted
# by HFInferenceAPI.chat_with_tools(tools=...). Every parameter carries a
# "description", because that text is the only guidance the LLM gets when
# filling in arguments.
NEUROIMAGING_TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "segment_with_sam3",
            "description": "Segment anatomical structures using SAM3 with text prompts. Best for general structures (brain, skull, ventricles, eyes).",
            "parameters": {
                "type": "object",
                "properties": {
                    "prompt": {
                        "type": "string",
                        "description": "Anatomical structure to segment (e.g., 'brain', 'tumor', 'ventricles')",
                    },
                    "modality": {
                        "type": "string",
                        "enum": ["CT", "MRI"],
                        "description": "Imaging modality",
                    },
                    "window_type": {
                        "type": "string",
                        "enum": ["Brain (Grey Matter)", "Bone (Skull)", "Default"],
                        "default": "Default",
                        "description": "Window preset to apply; use 'Default' unless a brain or bone window is needed",
                    },
                },
                "required": ["prompt", "modality"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "segment_with_medsam",
            "description": "Medical-optimized segmentation using MedSAM. Better for subtle pathology, tumors, lesions.",
            "parameters": {
                "type": "object",
                "properties": {
                    "bounding_box": {
                        "type": "array",
                        "items": {"type": "integer"},
                        # Exactly four coordinates, matching the documented layout.
                        "minItems": 4,
                        "maxItems": 4,
                        "description": "[x1, y1, x2, y2] bounding box for region of interest",
                    },
                    "modality": {
                        "type": "string",
                        "enum": ["CT", "MRI"],
                        "description": "Imaging modality",
                    },
                },
                "required": ["bounding_box", "modality"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "classify_image",
            "description": "Zero-shot medical image classification using BiomedCLIP. Identifies modality, body region, and pathology.",
            "parameters": {
                "type": "object",
                "properties": {
                    "candidate_labels": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "Labels to classify against",
                    },
                },
                "required": [],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "measure_roi",
            "description": "Calculate ROI statistics from segmentation. Returns area, intensity stats, centroid, bounding box.",
            "parameters": {
                "type": "object",
                "properties": {
                    "prompt": {
                        "type": "string",
                        "description": "What to segment and measure",
                    },
                    "modality": {
                        "type": "string",
                        "enum": ["CT", "MRI"],
                        "description": "Imaging modality",
                    },
                },
                "required": ["prompt", "modality"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "generate_report",
            "description": "Generate structured clinical or research report from findings.",
            "parameters": {
                "type": "object",
                "properties": {
                    "findings_summary": {
                        "type": "string",
                        "description": "Summary of segmentation findings and measurements",
                    },
                    "report_style": {
                        "type": "string",
                        "enum": ["radiology", "neurosurgery", "research"],
                        "default": "radiology",
                        "description": "Style of report to generate",
                    },
                    "clinical_context": {
                        "type": "string",
                        "description": "Optional patient history or clinical indication",
                    },
                },
                "required": ["findings_summary"],
            },
        },
    },
]
|
|
|
|
| |
# Module-level instance created at import time with the configured default
# provider; importers of this module all use this same object.
inference_client = HFInferenceAPI(DEFAULT_LLM_PROVIDER)
|
|