File size: 2,357 Bytes
9b2def5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from smolagents import Tool
import requests
import os

class ClassifyItemsTool(Tool):
    """smolagents tool that classifies items by prompting a Hugging Face endpoint.

    The four string inputs are rendered into a single prompt (see
    ``_build_prompt``), POSTed as ``{"inputs": prompt}`` to the configured
    endpoint, and the model's ``generated_text`` is returned when present.
    """

    name = 'hf_classifier'
    description = """Categorizes a list of items using domain-specific knowledge. Intended for structured classification tasks."""

    # Keys mirror the parameters of ``forward``; each value describes one
    # string argument the agent must supply.
    inputs = {
        'domain': {
            'type': 'string',
            'description': 'The knowledge domain to use for classification (e.g., medicine, education, etc.).'
        },
        'context': {
            'type': 'string',
            'description': 'A brief description of the environment where items appear (helps disambiguate meanings).'
        },
        'categories': {
            'type': 'string',
            'description': 'Comma-separated category list to classify items into.'
        },
        'items': {
            'type': 'string',
            'description': 'Comma-separated list of items to classify.'
        },
    }
    output_type = 'string'

    def __init__(
        self,
        hf_model_url: str | None = None,
        *,
        hf_token: str | None = None,
        timeout: float | None = 120.0,
        **kwargs,
    ):
        """Configure the endpoint the tool will call.

        Args:
            hf_model_url: URL of the inference endpoint. When omitted, the
                ``HF_ENDPOINT`` environment variable is used instead.
            hf_token: Optional bearer token sent in an ``Authorization``
                header. Leave as ``None`` for endpoints that need no auth.
            timeout: Value passed as ``timeout=`` to ``requests.post``
                (seconds). ``None`` disables the timeout entirely.
            **kwargs: Forwarded unchanged to ``Tool.__init__``.

        Raises:
            ValueError: if no endpoint URL is given and ``HF_ENDPOINT`` is unset
                or empty.
        """
        self.api_url = hf_model_url or os.getenv("HF_ENDPOINT")
        if not self.api_url:
            raise ValueError("HF_ENDPOINT must be set as environment variable or passed to constructor.")
        self.hf_token = hf_token
        self.request_timeout = timeout
        super().__init__(**kwargs)

    def forward(self, domain: str, context: str, categories: str, items: str) -> str:
        """Classify ``items`` into ``categories`` and return the model's answer.

        Args:
            domain: Knowledge domain the model should reason from.
            context: Environment in which the items appear.
            categories: Comma-separated category names.
            items: Comma-separated items to classify.

        Returns:
            The stripped ``generated_text`` from the endpoint when the response
            has that shape; otherwise ``str()`` of the decoded JSON payload.

        Raises:
            requests.HTTPError: if the endpoint responds with an error status.
            requests.Timeout: if the endpoint does not respond within the
                configured timeout.
        """
        prompt = self._build_prompt(domain, context, categories, items)

        headers = {"Content-Type": "application/json"}
        if self.hf_token:
            headers["Authorization"] = f"Bearer {self.hf_token}"

        # requests waits indefinitely unless a timeout is given explicitly,
        # which would freeze the calling agent on an unresponsive endpoint.
        response = requests.post(
            self.api_url,
            headers=headers,
            json={"inputs": prompt},
            timeout=self.request_timeout,
        )
        response.raise_for_status()
        return self._extract_text(response.json())

    @staticmethod
    def _extract_text(result) -> str:
        """Pull ``generated_text`` out of a decoded endpoint response.

        Handles both ``[{"generated_text": ...}, ...]`` and
        ``{"generated_text": ...}``. Any other payload is returned via ``str()``
        so the caller still sees what the endpoint sent. This includes an empty
        list, a list of non-dicts, or an error dict.
        """
        # Only look at the first element of a non-empty list; an empty list
        # would otherwise raise IndexError.
        payload = result[0] if isinstance(result, list) and result else result
        if isinstance(payload, dict) and 'generated_text' in payload:
            return str(payload['generated_text']).strip()
        return str(result)

    def _build_prompt(self, domain: str, context: str, categories: str, items: str) -> str:
        """Render the classification prompt sent to the model.

        The inputs are interpolated verbatim. The prompt asks for an
        explanation followed by one ``Category: item list`` line per category,
        plus an ``Other`` line for uncertain items.
        """
        return f"""
You are a {domain} expert working within a {context} context.
Classify the following items into the specified categories using your domain expertise.

Categories:
{categories}

Items to classify:
{items}

Return explanation followed by a classification in this format:
Category 1: item list
Category 2: item list
Other (if uncertain): item list
"""