| """ |
| Extract Model Info Node |
| |
| This node handles the extraction of model information from user queries. |
| It uses LLM to extract HuggingFace model names and fetches model metadata from the API. |
| |
| Key Features: |
| - LLM-based model name extraction |
| - HuggingFace API integration |
| - Error handling for invalid models |
| - State management for workflow |
| |
| Author: ComputeAgent Team |
| License: Private |
| """ |
|
|
| import logging |
| from typing import Dict, Any, Optional |
| import json |
| import aiohttp |
| from constant import Constants |
| from ComputeAgent.models.model_manager import ModelManager |
| from langchain_core.messages import HumanMessage, SystemMessage |
| from transformers import AutoConfig |
|
|
| |
# Shared ModelManager instance; the node below uses it to lazily load the
# LLM that performs model-name and dtype extraction.
model_manager = ModelManager()


# Module-level logger for this node (messages use emoji status prefixes).
logger = logging.getLogger("ExtractModelInfo")
|
|
|
|
async def extract_model_info_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract model information from user query and fetch model details.

    This node:
    1. Extracts the HuggingFace model name from the query using the LLM
    2. Fetches model metadata needed for memory estimation from HuggingFace
    3. Updates the state with the model information or an error status

    Args:
        state: Current workflow state; the user query is read from state["query"].

    Returns:
        The same state dict, updated with one of:
        - model_extraction_status == "unknown" (no model name; needs_generation set)
        - model_extraction_status == "error"   (fetch failed; state["error"] set)
        - model_extraction_status == "success" (model_name / model_info populated)
    """
    logger.info("🔍 Starting model information extraction")

    try:
        llm = await model_manager.load_llm_model(Constants.DEFAULT_LLM_NAME)

        query = state.get("query", "")
        logger.info(f"📝 Processing query: {query}")

        model_name = await extract_model_name_with_llm(query, llm)

        # extract_model_name_with_llm returns None for the UNKNOWN sentinel;
        # keep the explicit string check as a belt-and-braces guard in case
        # the raw LLM sentinel ever leaks through unconverted.
        if not model_name or model_name == "UNKNOWN":
            logger.info("❓ Model name not found, will need generation")
            state["model_extraction_status"] = "unknown"
            state["needs_generation"] = True
            return state

        logger.info(f"📋 Extracted model name: {model_name}")

        model_info = await fetch_huggingface_model_info_for_memory(model_name, llm)

        # BUG FIX: the fetch helper may return an empty dict on API failure,
        # so checking only for an "error" key let failures pass through as
        # "success" with no usable data. Treat an empty result as an error too.
        if not model_info or "error" in model_info:
            error_msg = model_info.get("error") if model_info else None
            error_msg = error_msg or f"Failed to fetch model info for {model_name}"
            logger.error(f"❌ Error fetching model info: {error_msg}")
            state["model_extraction_status"] = "error"
            state["error"] = error_msg
            return state

        state["model_name"] = model_name
        state["model_info"] = model_info
        state["model_extraction_status"] = "success"
        state["needs_generation"] = False

        logger.info(f"✅ Successfully extracted model info for {model_name}")
        return state

    except Exception as e:
        # Boundary handler: this node must never crash the workflow, so any
        # failure is folded into the state as an error status.
        logger.error(f"❌ Error during model info extraction: {str(e)}")
        state["model_extraction_status"] = "error"
        state["error"] = f"Model info extraction failed: {str(e)}"
        return state
|
|
|
|
async def extract_model_name_with_llm(query: str, llm) -> Optional[str]:
    """
    Use LLM to extract HuggingFace model name from user query.

    Args:
        query: User's natural language query
        llm: LangChain LLM instance (must support async ainvoke)

    Returns:
        Extracted model name in format 'owner/model-name', or None when the
        model cannot be determined from the query.
    """
    system_prompt = """
    You are an expert at extracting HuggingFace model names from user queries.

    Extract the exact HuggingFace model identifier in the format 'owner/model-name'.
    NEVER fabricate or guess model names. Only extract what is explicitly mentioned in the query.

    Rule for the UNKNOWN response:
    - If the model name is written but not the owner, respond with 'UNKNOWN'.
    - If the owner is written but not the model name, respond with 'UNKNOWN'.

    Only respond with the model identifier, nothing else.
    """

    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=f"Extract the HuggingFace model name from: {query}")
    ]

    response = await llm.ainvoke(messages)
    # LLMs occasionally wrap the identifier in quotes or backticks; strip them
    # so the downstream HuggingFace API lookup gets a clean identifier.
    model_name = response.content.strip().strip("\"'`")

    # BUG FIX: the annotation previously claimed -> str although None is
    # returned for the UNKNOWN sentinel; also treat an empty reply as unknown.
    if not model_name or model_name == "UNKNOWN":
        return None

    return model_name
|
|
|
|
async def extract_model_dtype_with_llm(model_name: str, parameters_dict: dict, llm) -> Optional[str]:
    """
    Use LLM to extract the correct dtype from model name and available parameters.

    Args:
        model_name: HuggingFace model name in format 'owner/model-name'
        parameters_dict: Available dtypes and their parameter counts from HF API
        llm: LangChain LLM instance

    Returns:
        Matching dtype key from parameters_dict or None if cannot be determined
    """
    # Robustness: with no candidate dtypes there is nothing to match, so skip
    # the (pointless) LLM round-trip entirely.
    if not parameters_dict:
        return None

    system_prompt = f"""
    You are an expert at identifying data types from HuggingFace model names.

    Given a model name and available dtype options, determine which dtype the model uses.

    Available dtypes: {json.dumps(list(parameters_dict.keys()))}

    Rules:
    - Analyze the model name for dtype indicators (FP8, BF16, INT4, INT8, FP16, etc.)
    - If no dtype indicator is found in the model name by default is BF16 on the model name side.
    - Return ONLY the dtype key that exists in the available options, nothing else

    Only respond with the dtype key or 'UNKNOWN', nothing else.
    """

    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=f"Extract the dtype from model name: {model_name}")
    ]

    response = await llm.ainvoke(messages)
    dtype = response.content.strip()

    # BUG FIX: the prompt explicitly allows an 'UNKNOWN' answer, so it is an
    # expected outcome — return None without the misleading warning below.
    if dtype == "UNKNOWN":
        return None

    if dtype not in parameters_dict:
        logger.warning(f"LLM returned dtype '{dtype}' not in available options: {list(parameters_dict.keys())}")
        return None

    return dtype
|
|
|
|
async def fetch_huggingface_model_info(model_name: str) -> Dict[str, Any]:
    """
    Fetch model information from HuggingFace API.

    Args:
        model_name: HuggingFace model identifier (e.g., 'meta-llama/Meta-Llama-3-70B')

    Returns:
        Dictionary containing model information, or a dict with an "error" key
        (and, for HTTP failures, a "status" key) when the lookup fails.
    """
    api_url = f"https://huggingface.co/api/models/{model_name}"

    async with aiohttp.ClientSession() as session:
        try:
            async with session.get(api_url) as response:
                status = response.status
                if status == 200:
                    payload = await response.json()
                    logger.info(f"✅ Successfully fetched model info for {model_name}")
                    return payload
                if status == 404:
                    logger.error(f"❌ Model not found: {model_name}")
                    return {"error": "Model not found", "status": 404}
                logger.error(f"❌ API error: {status}")
                return {"error": f"API error: {status}", "status": status}
        except Exception as e:
            # Network/parse failures are reported as data, not raised.
            logger.error(f"❌ Exception while fetching model info: {str(e)}")
            return {"error": str(e)}
| |
async def fetch_huggingface_model_info_for_memory(model_name: str, llm) -> Dict[str, Any]:
    """
    Fetch only the information needed for GPU memory estimation from HuggingFace.

    Args:
        model_name: HuggingFace model identifier in 'owner/model-name' format.
        llm: LangChain LLM instance used to resolve the model's dtype.

    Returns:
        Dictionary containing (when available):
        - num_params
        - dtype
        - num_hidden_layers
        - hidden_size
        - intermediate_size
        - num_attention_heads
        - num_key_value_heads
        - head_dim
        - max_position_embeddings
        plus hard-coded "location" / "GPU_type" defaults.
        On API failure, a dict with an "error" key is returned instead.
    """
    result: Dict[str, Any] = {}

    api_url = f"https://huggingface.co/api/models/{model_name}"
    async with aiohttp.ClientSession() as session:
        try:
            async with session.get(api_url) as response:
                if response.status == 200:
                    metadata = await response.json()
                else:
                    logger.error(f"❌ API error {response.status} for {model_name}")
                    # BUG FIX: previously returned {} here, which the caller's
                    # '"error" in model_info' check could never detect. Report
                    # failure explicitly, consistent with
                    # fetch_huggingface_model_info above.
                    return {"error": f"API error: {response.status}", "status": response.status}
        except Exception as e:
            logger.error(f"❌ Exception fetching metadata for {model_name}: {str(e)}")
            return {"error": str(e)}

    # 'safetensors' / 'parameters' may be present but null in some API
    # responses; coerce with `or {}` instead of relying on .get defaults.
    safetensors = metadata.get("safetensors") or {}
    parameters_dict = safetensors.get("parameters") or {}

    # Hard-coded deployment defaults (presumably placeholders — confirm).
    result["location"] = "UAE-1"
    result["GPU_type"] = "RTX4090"

    if parameters_dict:
        # parameters_dict maps dtype -> parameter count; ask the LLM which
        # dtype this model actually uses.
        result["dtype"] = await extract_model_dtype_with_llm(model_name, parameters_dict, llm)

        if result["dtype"]:
            result["num_params"] = parameters_dict[result["dtype"]]
            logger.info(f"✓ LLM selected dtype: {result['dtype']}")
        else:
            # Fall back to the first dtype the API reports.
            result["dtype"] = next(iter(parameters_dict.keys()))
            result["num_params"] = parameters_dict[result["dtype"]]
            logger.warning(f"⚠ Using fallback dtype: {result['dtype']}")
    else:
        result["dtype"] = "auto"
        result["num_params"] = metadata.get("num_params") or safetensors.get("total")

    try:
        token = Constants.HF_TOKEN if hasattr(Constants, 'HF_TOKEN') and Constants.HF_TOKEN else None

        if not token:
            logger.warning(f"⚠️ No HF_TOKEN provided for {model_name}")

        # SECURITY NOTE: trust_remote_code=True executes code from the model
        # repository locally — only safe for vetted model names.
        # NOTE(review): from_pretrained is a blocking call inside an async
        # function; consider asyncio.to_thread if this becomes a bottleneck.
        config = AutoConfig.from_pretrained(
            model_name,
            token=token,
            trust_remote_code=True
        )

        result.update({
            "num_hidden_layers": getattr(config, "num_hidden_layers", None),
            "hidden_size": getattr(config, "hidden_size", None),
            "intermediate_size": getattr(config, "intermediate_size", None),
            "num_attention_heads": getattr(config, "num_attention_heads", None),
            "num_key_value_heads": getattr(config, "num_key_value_heads", None),
            "max_position_embeddings": getattr(config, "max_position_embeddings", None),
        })

        # Configs without grouped-query attention omit num_key_value_heads;
        # in that case it equals num_attention_heads.
        if result["num_key_value_heads"] is None and result["num_attention_heads"] is not None:
            result["num_key_value_heads"] = result["num_attention_heads"]
            logger.info(f"ℹ️ Using num_attention_heads as num_key_value_heads for {model_name}")

        if result["hidden_size"] and result["num_attention_heads"]:
            result["head_dim"] = result["hidden_size"] // result["num_attention_heads"]

    except Exception as e:
        # Config fetch is best-effort: memory estimation can proceed with a
        # partial result, so log and continue rather than fail the node.
        logger.warning(f"⚠️ Could not fetch model config for {model_name}: {str(e)}")

    return result