Minibase committed on
Commit
827270e
·
verified ·
1 Parent(s): c9881c6

Upload ner_inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ner_inference.py +253 -0
ner_inference.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ NER-Small Inference Client
4
+
5
+ A Python client for running inference with the Minibase-NER-Small model.
6
+ Handles named entity recognition requests to the local llama.cpp server.
7
+ """
8
+
9
+ import requests
10
+ import json
11
+ from typing import Optional, Dict, Any, Tuple, List
12
+ import time
13
+ import re
14
+
15
+
16
+ class NERClient:
17
+ """
18
+ Client for the NER-Small named entity recognition model.
19
+
20
+ This client communicates with a local llama.cpp server running the
21
+ Minibase-NER-Small model for named entity recognition tasks.
22
+ """
23
+
24
+ def __init__(self, base_url: str = "http://127.0.0.1:8000", timeout: int = 30):
25
+ """
26
+ Initialize the NER client.
27
+
28
+ Args:
29
+ base_url: Base URL of the llama.cpp server
30
+ timeout: Request timeout in seconds
31
+ """
32
+ self.base_url = base_url.rstrip('/')
33
+ self.timeout = timeout
34
+ self.default_instruction = "Extract all named entities from the following text. List them as 1. Entity, 2. Entity, etc."
35
+
36
+ def _make_request(self, prompt: str, max_tokens: int = 512,
37
+ temperature: float = 0.1) -> Tuple[str, float]:
38
+ """
39
+ Make a completion request to the model.
40
+
41
+ Args:
42
+ prompt: The input prompt
43
+ max_tokens: Maximum tokens to generate
44
+ temperature: Sampling temperature
45
+
46
+ Returns:
47
+ Tuple of (response_text, latency_ms)
48
+ """
49
+ payload = {
50
+ "prompt": prompt,
51
+ "max_tokens": max_tokens,
52
+ "temperature": temperature
53
+ }
54
+
55
+ headers = {'Content-Type': 'application/json'}
56
+
57
+ start_time = time.time()
58
+ try:
59
+ response = requests.post(
60
+ f"{self.base_url}/completion",
61
+ json=payload,
62
+ headers=headers,
63
+ timeout=self.timeout
64
+ )
65
+
66
+ latency = (time.time() - start_time) * 1000 # Convert to milliseconds
67
+
68
+ if response.status_code == 200:
69
+ result = response.json()
70
+ return result.get('content', ''), latency
71
+ else:
72
+ return f"Error: HTTP {response.status_code}", latency
73
+
74
+ except requests.exceptions.RequestException as e:
75
+ latency = (time.time() - start_time) * 1000
76
+ return f"Error: {e}", latency
77
+
78
+ def extract_entities(self, text: str, instruction: Optional[str] = None,
79
+ max_tokens: int = 512, temperature: float = 0.1) -> List[Dict[str, Any]]:
80
+ """
81
+ Extract named entities from text.
82
+
83
+ Args:
84
+ text: Input text to analyze
85
+ instruction: Custom instruction (uses default if None)
86
+ max_tokens: Maximum tokens to generate
87
+ temperature: Sampling temperature
88
+
89
+ Returns:
90
+ List of entity dictionaries with text and metadata
91
+ """
92
+ if instruction is None:
93
+ instruction = self.default_instruction
94
+
95
+ prompt = f"{instruction}\n\nInput: {text}\n\nResponse: "
96
+
97
+ response_text, latency = self._make_request(prompt, max_tokens, temperature)
98
+
99
+ if response_text.startswith("Error"):
100
+ return []
101
+
102
+ # Parse the numbered list response
103
+ entities = self._parse_entity_response(response_text)
104
+
105
+ # Add metadata to each entity
106
+ for entity in entities:
107
+ entity.update({
108
+ 'confidence': 1.0, # Placeholder - model doesn't provide confidence
109
+ 'latency_ms': latency
110
+ })
111
+
112
+ return entities
113
+
114
+ def extract_entities_batch(self, texts: List[str], instruction: Optional[str] = None,
115
+ max_tokens: int = 512, temperature: float = 0.1) -> List[List[Dict[str, Any]]]:
116
+ """
117
+ Extract named entities from multiple texts.
118
+
119
+ Args:
120
+ texts: List of input texts to analyze
121
+ instruction: Custom instruction (uses default if None)
122
+ max_tokens: Maximum tokens to generate
123
+ temperature: Sampling temperature
124
+
125
+ Returns:
126
+ List of entity lists, one per input text
127
+ """
128
+ results = []
129
+ for text in texts:
130
+ entities = self.extract_entities(text, instruction, max_tokens, temperature)
131
+ results.append(entities)
132
+
133
+ return results
134
+
135
+ def _parse_entity_response(self, response_text: str) -> List[Dict[str, Any]]:
136
+ """
137
+ Parse the model's numbered list response into structured entities.
138
+
139
+ Args:
140
+ response_text: Raw model response
141
+
142
+ Returns:
143
+ List of entity dictionaries
144
+ """
145
+ entities = []
146
+
147
+ # Clean up the response
148
+ response_text = response_text.strip()
149
+
150
+ # Split by lines and process each line
151
+ lines = response_text.split('\n')
152
+
153
+ for line in lines:
154
+ line = line.strip()
155
+ if not line:
156
+ continue
157
+
158
+ # Try to extract entity names from numbered list format
159
+ # Pattern 1: "1. Entity Name" or "1. Entity Name - Description"
160
+ numbered_match = re.match(r'^\d+\.\s*(.+?)(?:\s*-\s*.+)?$', line)
161
+ if numbered_match:
162
+ entity_text = numbered_match.group(1).strip()
163
+ # Remove any trailing punctuation
164
+ entity_text = re.sub(r'[.,;:!?]$', '', entity_text).strip()
165
+ # Skip very short entities or generic terms
166
+ if entity_text and len(entity_text) > 1 and not entity_text.lower() in ['the', 'and', 'or', 'but', 'for', 'with']:
167
+ entities.append({
168
+ 'text': entity_text,
169
+ 'type': 'ENTITY', # Model doesn't specify types
170
+ 'start': 0, # Position information not available
171
+ 'end': 0
172
+ })
173
+
174
+ return entities
175
+
176
+ def health_check(self) -> bool:
177
+ """
178
+ Check if the model server is healthy and responding.
179
+
180
+ Returns:
181
+ True if server is healthy, False otherwise
182
+ """
183
+ try:
184
+ response = requests.get(f"{self.base_url}/health", timeout=5)
185
+ return response.status_code == 200
186
+ except:
187
+ return False
188
+
189
+ def get_model_info(self) -> Optional[Dict[str, Any]]:
190
+ """
191
+ Get information about the loaded model.
192
+
193
+ Returns:
194
+ Model information dictionary or None if unavailable
195
+ """
196
+ try:
197
+ response = requests.get(f"{self.base_url}/v1/models", timeout=5)
198
+ if response.status_code == 200:
199
+ return response.json()
200
+ except:
201
+ pass
202
+ return None
203
+
204
+
205
+ def main():
206
+ """
207
+ Command-line interface for NER inference.
208
+ """
209
+ import argparse
210
+
211
+ parser = argparse.ArgumentParser(description='NER-Small Inference Client')
212
+ parser.add_argument('text', help='Text to analyze for named entities')
213
+ parser.add_argument('--url', default='http://127.0.0.1:8000',
214
+ help='Model server URL (default: http://127.0.0.1:8000)')
215
+ parser.add_argument('--max-tokens', type=int, default=512,
216
+ help='Maximum tokens to generate (default: 512)')
217
+ parser.add_argument('--temperature', type=float, default=0.1,
218
+ help='Sampling temperature (default: 0.1)')
219
+
220
+ args = parser.parse_args()
221
+
222
+ # Initialize client
223
+ client = NERClient(args.url)
224
+
225
+ # Check server health
226
+ if not client.health_check():
227
+ print(f"❌ Error: Cannot connect to model server at {args.url}")
228
+ print("Make sure the llama.cpp server is running with the NER-Small model.")
229
+ return 1
230
+
231
+ # Extract entities
232
+ entities = client.extract_entities(
233
+ args.text,
234
+ max_tokens=args.max_tokens,
235
+ temperature=args.temperature
236
+ )
237
+
238
+ # Display results
239
+ print(f"📝 Input Text: {args.text}")
240
+ print(f"🎯 Found {len(entities)} entities:")
241
+ print()
242
+
243
+ if entities:
244
+ for i, entity in enumerate(entities, 1):
245
+ print(f"{i}. {entity['text']} (Type: {entity['type']})")
246
+ else:
247
+ print("No entities found.")
248
+
249
+ return 0
250
+
251
+
252
if __name__ == "__main__":
    # Raise SystemExit directly: the `exit()` builtin is installed by the
    # `site` module and is not guaranteed to exist (e.g. under `python -S`
    # or in frozen/embedded interpreters).
    raise SystemExit(main())