FD900 commited on
Commit
9b2def5
·
verified ·
1 Parent(s): e0152b8

Update tools/classifier_tool.py

Browse files
Files changed (1) hide show
  1. tools/classifier_tool.py +67 -0
tools/classifier_tool.py CHANGED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import Tool
2
+ import requests
3
+ import os
4
+
5
+ class ClassifyItemsTool(Tool):
6
+ name = 'hf_classifier'
7
+ description = """Categorizes a list of items using domain-specific knowledge. Intended for structured classification tasks."""
8
+
9
+ inputs = {
10
+ 'domain': {
11
+ 'type': 'string',
12
+ 'description': 'The knowledge domain to use for classification (e.g., medicine, education, etc.).'
13
+ },
14
+ 'context': {
15
+ 'type': 'string',
16
+ 'description': 'A brief description of the environment where items appear (helps disambiguate meanings).'
17
+ },
18
+ 'categories': {
19
+ 'type': 'string',
20
+ 'description': 'Comma-separated category list to classify items into.'
21
+ },
22
+ 'items': {
23
+ 'type': 'string',
24
+ 'description': 'Comma-separated list of items to classify.'
25
+ },
26
+ }
27
+ output_type = 'string'
28
+
29
+ def __init__(self, hf_model_url: str | None = None, **kwargs):
30
+ self.api_url = hf_model_url or os.getenv("HF_ENDPOINT")
31
+ if not self.api_url:
32
+ raise ValueError("HF_ENDPOINT must be set as environment variable or passed to constructor.")
33
+ super().__init__(**kwargs)
34
+
35
+ def forward(self, domain: str, context: str, categories: str, items: str) -> str:
36
+ prompt = self._build_prompt(domain, context, categories, items)
37
+
38
+ response = requests.post(
39
+ self.api_url,
40
+ headers={"Content-Type": "application/json"},
41
+ json={"inputs": prompt}
42
+ )
43
+ response.raise_for_status()
44
+ result = response.json()
45
+
46
+ if isinstance(result, list) and 'generated_text' in result[0]:
47
+ return result[0]['generated_text'].strip()
48
+ elif isinstance(result, dict) and 'generated_text' in result:
49
+ return result['generated_text'].strip()
50
+ return str(result)
51
+
52
+ def _build_prompt(self, domain: str, context: str, categories: str, items: str) -> str:
53
+ return f"""
54
+ You are a {domain} expert working within a {context} context.
55
+ Classify the following items into the specified categories using your domain expertise.
56
+
57
+ Categories:
58
+ {categories}
59
+
60
+ Items to classify:
61
+ {items}
62
+
63
+ Return explanation followed by a classification in this format:
64
+ Category 1: item list
65
+ Category 2: item list
66
+ Other (if uncertain): item list
67
+ """