# Imported from upstream: finish (#8), commit 5dd4236 (carraraig)
"""
Extract Model Info Node
This node handles the extraction of model information from user queries.
It uses LLM to extract HuggingFace model names and fetches model metadata from the API.
Key Features:
- LLM-based model name extraction
- HuggingFace API integration
- Error handling for invalid models
- State management for workflow
Author: ComputeAgent Team
License: Private
"""
import logging
from typing import Dict, Any, Optional
import json
import aiohttp
from constant import Constants
from ComputeAgent.models.model_manager import ModelManager
from langchain_core.messages import HumanMessage, SystemMessage
from transformers import AutoConfig
# Initialize model manager for dynamic LLM loading and management
# (shared module-level singleton; every node call reuses this instance)
model_manager = ModelManager()
# Module-level logger named after the node for easy filtering in aggregated logs
logger = logging.getLogger("ExtractModelInfo")
async def extract_model_info_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Extract model information from user query and fetch model details.

    This node:
    1. Extracts model name from query using LLM
    2. Fetches model info from HuggingFace API
    3. Updates state with model information or error status

    Args:
        state: Current workflow state containing query

    Returns:
        Updated state with model information or extraction status
        (sets "model_extraction_status" to "success", "unknown", or "error")
    """
    logger.info("🔍 Starting model information extraction")
    try:
        # Initialize LLM
        llm = await model_manager.load_llm_model(Constants.DEFAULT_LLM_NAME)

        # Extract model name from query using LLM
        query = state.get("query", "")
        logger.info(f"📝 Processing query: {query}")
        model_name = await extract_model_name_with_llm(query, llm)

        # extract_model_name_with_llm maps "UNKNOWN" to None; keep the string
        # comparison as a belt-and-braces guard in case the raw value leaks through.
        if not model_name or model_name == "UNKNOWN":
            logger.info("❓ Model name not found, will need generation")
            state["model_extraction_status"] = "unknown"
            state["needs_generation"] = True
            return state

        logger.info(f"📋 Extracted model name: {model_name}")

        # Fetch model information
        model_info = await fetch_huggingface_model_info_for_memory(model_name, llm)

        # BUG FIX: the fetch helper may return an empty dict on API/network
        # failure; previously an empty dict fell through to the "success" path
        # with no usable model data. Treat both empty and explicit-error
        # results as failures.
        if not model_info or "error" in model_info:
            error_msg = model_info["error"] if model_info and "error" in model_info else "Failed to fetch model info"
            logger.error(f"❌ Error fetching model info: {error_msg}")
            state["model_extraction_status"] = "error"
            state["error"] = error_msg
            return state

        # Success - update state with model information
        state["model_name"] = model_name
        state["model_info"] = model_info
        state["model_extraction_status"] = "success"
        state["needs_generation"] = False
        logger.info(f"✅ Successfully extracted model info for {model_name}")
        return state
    except Exception as e:
        # Top-level boundary: record the failure in the state rather than
        # crashing the workflow.
        logger.error(f"❌ Error during model info extraction: {str(e)}")
        state["model_extraction_status"] = "error"
        state["error"] = f"Model info extraction failed: {str(e)}"
        return state
async def extract_model_name_with_llm(query: str, llm) -> Optional[str]:
    """
    Use LLM to extract HuggingFace model name from user query.

    Args:
        query: User's natural language query
        llm: LangChain LLM instance

    Returns:
        Extracted model name in format 'owner/model-name', or None when the
        LLM answers 'UNKNOWN' (model or owner not explicitly present in the query).
    """
    # BUG FIX: return annotation was `-> str` although the function explicitly
    # returns None for the UNKNOWN case (as the docstring already stated).
    system_prompt = """
    You are an expert at extracting HuggingFace model names from user queries.
    Extract the exact HuggingFace model identifier in the format 'owner/model-name'.
    NEVER fabricate or guess model names. Only extract what is explicitly mentioned in the query.
    Rule for the UNKNOWN response:
    - If the model name is written but not the owner, respond with 'UNKNOWN'.
    - If the owner is written but not the model name, respond with 'UNKNOWN'.
    Only respond with the model identifier, nothing else.
    """
    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=f"Extract the HuggingFace model name from: {query}")
    ]
    response = await llm.ainvoke(messages)
    model_name = response.content.strip()
    # Normalize the sentinel response to None so callers can use truthiness.
    if model_name == "UNKNOWN":
        return None
    return model_name
async def extract_model_dtype_with_llm(model_name: str, parameters_dict: dict, llm) -> Optional[str]:
    """
    Use LLM to extract the correct dtype from model name and available parameters.

    Args:
        model_name: HuggingFace model name in format 'owner/model-name'
        parameters_dict: Available dtypes and their parameter counts from HF API
        llm: LangChain LLM instance

    Returns:
        Matching dtype key from parameters_dict or None if it cannot be determined
    """
    system_prompt = f"""
    You are an expert at identifying data types from HuggingFace model names.
    Given a model name and available dtype options, determine which dtype the model uses.
    Available dtypes: {json.dumps(list(parameters_dict.keys()))}
    Rules:
    - Analyze the model name for dtype indicators (FP8, BF16, INT4, INT8, FP16, etc.)
    - If no dtype indicator is found in the model name by default is BF16 on the model name side.
    - Return ONLY the dtype key that exists in the available options, nothing else
    Only respond with the dtype key or 'UNKNOWN', nothing else.
    """
    messages = [
        SystemMessage(content=system_prompt),
        HumanMessage(content=f"Extract the dtype from model name: {model_name}")
    ]
    response = await llm.ainvoke(messages)
    dtype = response.content.strip()
    # BUG FIX: 'UNKNOWN' is an expected response per the prompt, not an
    # anomaly — handle it explicitly instead of logging the misleading
    # "not in available options" warning.
    if dtype == "UNKNOWN":
        logger.info(f"LLM could not determine dtype for {model_name}")
        return None
    # Validate that the returned dtype exists in parameters_dict
    if dtype not in parameters_dict:
        logger.warning(f"LLM returned dtype '{dtype}' not in available options: {list(parameters_dict.keys())}")
        return None
    return dtype
async def fetch_huggingface_model_info(model_name: str) -> Dict[str, Any]:
    """
    Retrieve a model's full metadata record from the HuggingFace models API.

    Args:
        model_name: HuggingFace model identifier (e.g., 'meta-llama/Meta-Llama-3-70B')

    Returns:
        The API's JSON payload on success, or a dict with an "error" key
        (and, for HTTP failures, a "status" key) on failure.
    """
    api_url = f"https://huggingface.co/api/models/{model_name}"
    async with aiohttp.ClientSession() as session:
        try:
            async with session.get(api_url) as response:
                status_code = response.status
                # Happy path: parse and hand back the JSON body as-is.
                if status_code == 200:
                    payload = await response.json()
                    logger.info(f"✅ Successfully fetched model info for {model_name}")
                    return payload
                # Distinguish "no such model" from other HTTP-level failures.
                if status_code == 404:
                    logger.error(f"❌ Model not found: {model_name}")
                    return {"error": "Model not found", "status": 404}
                logger.error(f"❌ API error: {status_code}")
                return {"error": f"API error: {status_code}", "status": status_code}
        except Exception as e:
            # Network/parse errors surface as an error dict, never an exception.
            logger.error(f"❌ Exception while fetching model info: {str(e)}")
            return {"error": str(e)}
async def fetch_huggingface_model_info_for_memory(model_name: str, llm) -> Dict[str, Any]:
    """
    Fetch only the information needed for GPU memory estimation from HuggingFace.

    Combines two sources:
    1. The HuggingFace models API — parameter counts per dtype (safetensors).
    2. transformers AutoConfig — architecture dimensions.

    Args:
        model_name: HuggingFace model identifier ('owner/model-name')
        llm: LangChain LLM instance used to pick the dtype when several are listed

    Returns:
        Dictionary containing:
        - num_params, dtype
        - num_hidden_layers, hidden_size, intermediate_size
        - num_attention_heads, num_key_value_heads, head_dim
        - max_position_embeddings
        - location, GPU_type (defaults)
        On API failure, a dict with an "error" key (consistent with
        fetch_huggingface_model_info) so callers can detect the failure.
    """
    result: Dict[str, Any] = {}

    # Step 1: Fetch metadata from HuggingFace API
    api_url = f"https://huggingface.co/api/models/{model_name}"
    async with aiohttp.ClientSession() as session:
        try:
            async with session.get(api_url) as response:
                if response.status == 200:
                    metadata = await response.json()
                else:
                    logger.error(f"❌ API error {response.status} for {model_name}")
                    # BUG FIX: previously returned {}, which the caller's
                    # '"error" in model_info' check silently treated as success.
                    return {"error": f"API error: {response.status}", "status": response.status}
        except Exception as e:
            logger.error(f"❌ Exception fetching metadata for {model_name}: {str(e)}")
            return {"error": str(e)}

    # Extract num_params and dtype from the safetensors parameter breakdown
    safetensors = metadata.get("safetensors", {})
    parameters_dict = safetensors.get("parameters", {})

    result["location"] = "UAE-1"  # Default location
    result["GPU_type"] = "RTX4090"  # Default GPU type

    if parameters_dict:
        # Let the LLM match the dtype implied by the model name against the
        # dtype keys reported by the API.
        result["dtype"] = await extract_model_dtype_with_llm(model_name, parameters_dict, llm)
        if result["dtype"]:
            result["num_params"] = parameters_dict[result["dtype"]]
            logger.info(f"✓ LLM selected dtype: {result['dtype']}")
        else:
            # Fallback to first available if LLM couldn't determine
            result["dtype"] = next(iter(parameters_dict.keys()))
            result["num_params"] = parameters_dict[result["dtype"]]
            logger.warning(f"⚠ Using fallback dtype: {result['dtype']}")
    else:
        result["dtype"] = "auto"
        # NOTE(review): the API record may not expose "num_params";
        # safetensors "total" is the usual fallback — num_params can be None.
        result["num_params"] = metadata.get("num_params") or safetensors.get("total")

    # Step 2: Fetch model config via transformers
    try:
        # Check if a HF token is configured; gated models require one.
        token = Constants.HF_TOKEN if hasattr(Constants, 'HF_TOKEN') and Constants.HF_TOKEN else None
        if not token:
            logger.warning(f"⚠️ No HF_TOKEN provided for {model_name}")
        config = AutoConfig.from_pretrained(
            model_name,
            token=token,
            trust_remote_code=True  # Some models ship custom config code
        )
        result.update({
            "num_hidden_layers": getattr(config, "num_hidden_layers", None),
            "hidden_size": getattr(config, "hidden_size", None),
            "intermediate_size": getattr(config, "intermediate_size", None),
            "num_attention_heads": getattr(config, "num_attention_heads", None),
            "num_key_value_heads": getattr(config, "num_key_value_heads", None),
            "max_position_embeddings": getattr(config, "max_position_embeddings", None),
        })
        # Fallback: if num_key_value_heads is not available, use num_attention_heads
        # (non-GQA models use one KV head per attention head).
        if result["num_key_value_heads"] is None and result["num_attention_heads"] is not None:
            result["num_key_value_heads"] = result["num_attention_heads"]
            logger.info(f"ℹ️ Using num_attention_heads as num_key_value_heads for {model_name}")
        # Optional: compute head_dim
        if result["hidden_size"] and result["num_attention_heads"]:
            result["head_dim"] = result["hidden_size"] // result["num_attention_heads"]
    except Exception as e:
        # Best-effort: memory estimation can proceed with partial info.
        logger.warning(f"⚠️ Could not fetch model config for {model_name}: {str(e)}")

    return result