# Final_Assignment_Template/tools/classifier_tool.py
# Last change: commit 9b2def5, "Update tools/classifier_tool.py" (author: FD900)
from smolagents import Tool
import requests
import os
class ClassifyItemsTool(Tool):
    """smolagents tool that classifies items by prompting a remote text-generation endpoint.

    The tool builds one natural-language prompt from the domain, context,
    category list and item list. It POSTs that prompt as ``{"inputs": prompt}``
    to the configured endpoint URL and returns the model's generated text.
    """

    name = 'hf_classifier'
    description = """Categorizes a list of items using domain-specific knowledge. Intended for structured classification tasks."""
    inputs = {
        'domain': {
            'type': 'string',
            'description': 'The knowledge domain to use for classification (e.g., medicine, education, etc.).'
        },
        'context': {
            'type': 'string',
            'description': 'A brief description of the environment where items appear (helps disambiguate meanings).'
        },
        'categories': {
            'type': 'string',
            'description': 'Comma-separated category list to classify items into.'
        },
        'items': {
            'type': 'string',
            'description': 'Comma-separated list of items to classify.'
        },
    }
    output_type = 'string'

    def __init__(
        self,
        hf_model_url: str | None = None,
        *,
        hf_token: str | None = None,
        timeout: float = 120.0,
        **kwargs,
    ):
        """Configure the endpoint the tool will call.

        Args:
            hf_model_url: Endpoint URL. Falls back to the ``HF_ENDPOINT``
                environment variable when not given.
            hf_token: Optional bearer token sent as ``Authorization: Bearer ...``.
                Falls back to the ``HF_TOKEN`` environment variable. When neither
                is set, no Authorization header is sent (the original behaviour).
            timeout: Seconds passed to ``requests.post`` as ``timeout``. Without
                it, an unresponsive endpoint would block the agent indefinitely.
            **kwargs: Forwarded to ``Tool.__init__``.

        Raises:
            ValueError: If no endpoint URL is available from either source.
        """
        self.api_url = hf_model_url or os.getenv("HF_ENDPOINT")
        if not self.api_url:
            raise ValueError("HF_ENDPOINT must be set as environment variable or passed to constructor.")
        self.hf_token = hf_token or os.getenv("HF_TOKEN")
        self.timeout = timeout
        super().__init__(**kwargs)

    def forward(self, domain: str, context: str, categories: str, items: str) -> str:
        """Send the classification prompt to the endpoint and return the model's text.

        Returns:
            The stripped ``generated_text`` field when the response has one.
            Otherwise, ``str()`` of the decoded JSON response.

        Raises:
            requests.HTTPError: On a non-2xx response (via ``raise_for_status``).
            requests.Timeout: If the endpoint does not answer within ``self.timeout``.
        """
        prompt = self._build_prompt(domain, context, categories, items)
        headers = {"Content-Type": "application/json"}
        if self.hf_token:
            headers["Authorization"] = f"Bearer {self.hf_token}"
        response = requests.post(
            self.api_url,
            headers=headers,
            json={"inputs": prompt},
            timeout=self.timeout,
        )
        response.raise_for_status()
        return self._extract_text(response.json())

    def _extract_text(self, result) -> str:
        """Return the stripped ``generated_text`` from a decoded response, else ``str(result)``.

        Accepts either a list whose first element is a dict with
        ``generated_text`` or a single such dict. Anything else (an empty list,
        a list of strings, an error payload) is stringified instead of raising.
        """
        if isinstance(result, list):
            # Guard against [] and non-dict elements. A bare `in` check on a str
            # element would do a substring test and then fail on indexing.
            if result and isinstance(result[0], dict) and 'generated_text' in result[0]:
                return str(result[0]['generated_text']).strip()
        elif isinstance(result, dict) and 'generated_text' in result:
            return str(result['generated_text']).strip()
        return str(result)

    def _build_prompt(self, domain: str, context: str, categories: str, items: str) -> str:
        """Render the classification instructions that are sent as the request's ``inputs``."""
        return f"""
You are a {domain} expert working within a {context} context.
Classify the following items into the specified categories using your domain expertise.
Categories:
{categories}
Items to classify:
{items}
Return explanation followed by a classification in this format:
Category 1: item list
Category 2: item list
Other (if uncertain): item list
"""