rustemgareev committed on
Commit
eb59cf9
·
1 Parent(s): f3e233f

Upload app files

Browse files
Files changed (4) hide show
  1. Dockerfile +20 -0
  2. app.py +282 -0
  3. requirements.txt +6 -0
  4. static/index.html +301 -0
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Container image for the NER demo: installs the Python dependencies and
# serves the FastAPI app with uvicorn on port 7860.
FROM python:3.10-slim

WORKDIR /app

# Keep Python output unbuffered and point the model caches at a directory
# that is made world-writable below.
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    TRANSFORMERS_CACHE=/app/cache \
    HF_HOME=/app/cache

# The process runs as a non-root user (see USER below), so the cache
# directory must be writable by it.
RUN mkdir -p /app/cache && chmod -R 777 /app/cache

# Install dependencies before copying the application code so this layer is
# reused when only the code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY app.py .
# app.py mounts StaticFiles(directory="static"); the directory has to exist
# inside the image or the application cannot serve the front-end.
COPY static ./static

RUN useradd -m -u 1000 user
USER user

CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from contextlib import asynccontextmanager
3
+ from typing import List, Dict, Any
4
+ from fastapi import FastAPI, HTTPException
5
+ from fastapi.staticfiles import StaticFiles
6
+ from pydantic import BaseModel, Field
7
+ from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
8
+
9
# Configure root logging at INFO level and create the module-level logger
# used by the startup and request-handling code below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
11
+
12
+ # Config
13
# Request body of POST /predict.
class NERRequest(BaseModel):
    # Raw text in which entities are to be found (required field).
    text: str = Field(..., title="Input Text", description="Text to analyze")
15
+
16
# One aggregated entity as produced by save_current_entity().
class NEREntity(BaseModel):
    # Entity type with the tagging prefix removed (e.g. "B-ORG" -> "ORG").
    entity_group: str
    # Mean of the scores of the tokens that were merged into this entity.
    score: float
    # Surface text of the entity, taken from the input after boundary refinement.
    word: str
    # Character indices into the input text; the entity is text[start:end].
    start: int
    end: int
22
+
23
# Response body of POST /predict.
class NERResponse(BaseModel):
    # Entities found in the request text; empty when the text is blank.
    entities: List[NEREntity]
25
+
26
+ # Constants
27
# Texts with at most this many tokens are sent to the pipeline in one call.
SHORT_TEXT_THRESHOLD = 128
# Maximum sequence length configured on the tokenizer; also the basis for the
# sliding-window size (in tokens) used for longer texts.
MODEL_MAX_LENGTH = 512
# Number of tokens shared by two consecutive sliding windows.
WINDOW_OVERLAP = 128
30
+
31
+ # Core Logic
32
def refine_boundaries(text: str, start: int, end: int) -> tuple[int, int, str]:
    """
    Tighten a character span so it covers whole words without outer whitespace.

    1. Clamps the indices to the bounds of ``text``.
    2. Trims leading/trailing whitespace from the span.
    3. If the trimmed span ends on an alphanumeric character, extends it to
       the end of that word (the model stopped mid-word).

    Args:
        text: The full input text the indices refer to.
        start: Index of the first character of the raw span.
        end: Index one past the last character of the raw span.

    Returns:
        ``(start, end, span)`` where ``span == text[start:end]``. ``span`` is
        empty when the raw span contained only whitespace.
    """
    # Keep the indices inside the text so the returned pair is always sliceable.
    start = max(0, min(start, len(text)))
    end = max(start, min(end, len(text)))

    # Trim first: expanding before trimming would let a span that ends in a
    # space absorb the following, unrelated word.
    while start < end and text[start].isspace():
        start += 1
    while end > start and text[end - 1].isspace():
        end -= 1

    # Only a span that ends inside a word is extended to the word's end.
    if end > start and text[end - 1].isalnum():
        while end < len(text) and text[end].isalnum():
            end += 1

    return start, end, text[start:end]
58
+
59
def refine_boundaries1(text: str, start: int, end: int) -> tuple[int, int, str]:
    """
    Shrink a character span so it excludes leading/trailing whitespace.

    Unlike ``refine_boundaries`` this variant never extends the span; it only
    trims, so a highlight built from the result sits tightly around the text.

    Args:
        text: The full input text the indices refer to.
        start: Index of the first character of the raw span.
        end: Index one past the last character of the raw span.

    Returns:
        ``(start, end, span)`` with the indices moved inward by the amount of
        whitespace removed and ``span`` being the trimmed text.
    """
    span = text[start:end]

    # Move the start forward by the number of leading whitespace characters.
    without_leading = span.lstrip()
    start += len(span) - len(without_leading)

    # Move the end backward by the number of trailing whitespace characters.
    trimmed = without_leading.rstrip()
    end -= len(without_leading) - len(trimmed)

    return start, end, trimmed
78
+
79
def save_current_entity(entity_parts: List[Dict], full_text: str, aggregated_entities: List[Dict]):
    """
    Finalize a group of consecutive tokens into a single entity.

    The merged entity is appended to ``aggregated_entities`` in place; nothing
    is appended when ``entity_parts`` is empty or the refined span is empty.

    Args:
        entity_parts: Token dicts of one entity, in text order. Each must have
            the keys ``'start'``, ``'end'``, ``'score'`` and ``'entity'``.
        full_text: The text the token offsets refer to.
        aggregated_entities: Output list that receives the merged entity.
    """
    if not entity_parts:
        return

    # 1. Raw range: from the first token's start to the last token's end.
    raw_start = entity_parts[0]['start']
    raw_end = entity_parts[-1]['end']

    # 2. Refine boundaries (whole words, no surrounding whitespace).
    final_start, final_end, clean_word = refine_boundaries(full_text, raw_start, raw_end)

    if not clean_word:
        return

    # 3. Score of the entity is the mean of its token scores.
    avg_score = sum(part['score'] for part in entity_parts) / len(entity_parts)

    # 4. Label comes from the first token. Strip only the leading tag prefix
    #    (split once) so a hyphenated type keeps its full name, matching how
    #    aggregate_entities_manual parses the current token's label.
    raw_label = entity_parts[0]['entity']
    entity_group = raw_label.split('-', 1)[-1]  # e.g., "B-ORG" -> "ORG"

    aggregated_entities.append({
        'word': clean_word,
        'score': float(avg_score),
        'entity_group': entity_group,
        'start': final_start,
        'end': final_end
    })
111
+
112
def aggregate_entities_manual(ner_results: List[Dict], full_text: str) -> List[Dict]:
    """
    Aggregate raw (sub)token predictions into whole entities.

    Consecutive tokens are merged when they have the same entity type and the
    current token starts exactly where the previous one ended. The B-/I-
    prefix is deliberately not used to force a split: adjacency plus type
    match is prioritized for smoother highlighting.

    Args:
        ner_results: Token dicts in text order, each with the keys
            ``'entity'``, ``'score'``, ``'start'`` and ``'end'``.
        full_text: The text the token offsets refer to.

    Returns:
        A list of entity dicts as built by ``save_current_entity``.
    """
    if not ner_results:
        return []

    aggregated_entities: List[Dict] = []
    current_entity_parts: List[Dict] = []

    for entity in ner_results:
        entity_label = entity['entity']

        # 'O' (Outside) closes any open entity and is otherwise ignored.
        if entity_label == 'O':
            if current_entity_parts:
                save_current_entity(current_entity_parts, full_text, aggregated_entities)
                current_entity_parts = []
            continue

        # Entity type without the tag prefix ("B-ORG" / "I-ORG" -> "ORG").
        # Both labels below are split the same way (once), so types that
        # themselves contain a hyphen still compare equal.
        label_type = entity_label.split('-', 1)[-1]

        if current_entity_parts:
            previous = current_entity_parts[-1]
            prev_type = previous['entity'].split('-', 1)[-1]

            # Merge: same type and directly adjacent in the text.
            if label_type == prev_type and entity['start'] == previous['end']:
                current_entity_parts.append(entity)
                continue

            # Otherwise close the previous entity before starting a new one.
            save_current_entity(current_entity_parts, full_text, aggregated_entities)

        current_entity_parts = [entity]

    # Save tail
    if current_entity_parts:
        save_current_entity(current_entity_parts, full_text, aggregated_entities)

    return aggregated_entities
164
+
165
+ # Smart Processing Logic
166
def process_text_smart(text: str, pipe, tokenizer) -> List[Dict]:
    """
    Run NER on ``text``: one call for short texts, sliding windows for long ones.

    Args:
        text: The text to analyze.
        pipe: Callable that takes a string and returns a list of token dicts
            with at least ``'start'`` and ``'end'`` character offsets.
        tokenizer: Tokenizer callable that supports
            ``return_offsets_mapping=True``.

    Returns:
        RAW (unaggregated) token dicts whose offsets refer to ``text``. For
        long texts they are sorted by start and de-duplicated on
        ``(start, end)``.
    """
    tokenized = tokenizer(
        text,
        return_offsets_mapping=True,
        add_special_tokens=False,
        verbose=False
    )
    offsets = tokenized["offset_mapping"]
    total_tokens = len(offsets)

    # STRATEGY A: Short Text
    if total_tokens <= SHORT_TEXT_THRESHOLD:
        return pipe(text)

    # STRATEGY B: Sliding Window
    # The offsets above were computed without special tokens, but the
    # pipeline adds them when it tokenizes each chunk. Reserve room for them
    # so a full window still fits within MODEL_MAX_LENGTH instead of losing
    # its last tokens. Tokenizers without the helper keep the full size.
    count_special = getattr(tokenizer, "num_special_tokens_to_add", None)
    reserved = count_special(pair=False) if callable(count_special) else 0
    window = max(1, MODEL_MAX_LENGTH - reserved)
    # Guard against a non-positive step, which range() would reject or which
    # would never advance.
    step = max(1, window - WINDOW_OVERLAP)

    all_raw_tokens = []

    for start_idx in range(0, total_tokens, step):
        end_idx = min(start_idx + window, total_tokens)

        # Map the token window back to a character range of the text.
        char_start = offsets[start_idx][0]
        char_end = offsets[end_idx - 1][1]

        chunk_text = text[char_start:char_end]
        if chunk_text.strip():
            for ent in pipe(chunk_text):
                # Chunk-relative offsets -> offsets into the full text.
                ent["start"] += char_start
                ent["end"] += char_start
                all_raw_tokens.append(ent)

        # The window that reaches the last token is the final one.
        if end_idx == total_tokens:
            break

    # Overlapping windows report the same token twice. The sort is stable, so
    # for each (start, end) the prediction of the earlier window is kept.
    all_raw_tokens.sort(key=lambda x: x['start'])
    unique_tokens: Dict = {}
    for t in all_raw_tokens:
        unique_tokens.setdefault((t['start'], t['end']), t)

    return list(unique_tokens.values())
220
+
221
+ # Lifespan
222
# Registry of loaded inference objects. lifespan() stores the pipeline under
# "ner" and the tokenizer under "tokenizer"; request handlers read from it.
ml_models: Dict[str, Any] = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Load the NER model when the application starts and release it on shutdown.

    A loading failure is logged (with traceback) but does not stop the
    application from starting; ``ml_models`` then stays empty, which the
    /predict handler reports as HTTP 503.
    """
    model_name = "rustemgareev/mdeberta-ner-ontonotes5"
    logger.info(f"Loading model: {model_name}...")

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, model_max_length=MODEL_MAX_LENGTH)
        model = AutoModelForTokenClassification.from_pretrained(model_name)

        # aggregation_strategy="none": the pipeline returns raw tokens; merging
        # into entities is done by aggregate_entities_manual().
        ner_pipe = pipeline(
            "ner",
            model=model,
            tokenizer=tokenizer,
            aggregation_strategy="none",
            device=-1
        )

        ml_models["ner"] = ner_pipe
        ml_models["tokenizer"] = tokenizer
        logger.info("Model loaded.")

    except Exception as e:
        # Best-effort startup: keep serving, but record the full traceback so
        # the cause of the failure can be diagnosed from the logs.
        logger.exception(f"CRITICAL ERROR loading model: {e}")

    try:
        yield
    finally:
        # Drop the references even if shutdown is triggered by an exception.
        ml_models.clear()
250
+
251
+ # App Init
252
# ASGI application object (referenced as "app:app" by the uvicorn command);
# lifespan() above loads the model on startup and clears it on shutdown.
app = FastAPI(title="mDeBERTa NER API", version="3.3.0", lifespan=lifespan)
253
+
254
+ # API Endpoints
255
@app.post("/predict", response_model=NERResponse)
def predict(request: NERRequest):
    """
    Find named entities in ``request.text``.

    Returns:
        NERResponse with the aggregated entities; an empty list when the text
        is blank.

    Raises:
        HTTPException: 503 when the model has not been loaded, 500 when
            inference or aggregation fails.
    """
    if "ner" not in ml_models:
        raise HTTPException(status_code=503, detail="Model not loaded")

    if not request.text.strip():
        return NERResponse(entities=[])

    try:
        # 1. Get Raw Tokens
        raw_tokens = process_text_smart(
            request.text,
            ml_models["ner"],
            ml_models["tokenizer"]
        )

        # 2. Aggregation & Boundary Refinement
        # request.text is passed so offsets can be trimmed against the
        # original characters.
        aggregated = aggregate_entities_manual(raw_tokens, request.text)

        return NERResponse(entities=[NEREntity(**item) for item in aggregated])

    except Exception as e:
        # Log the full traceback server-side, but do not echo internal
        # exception text back to an untrusted client.
        logger.exception(f"Prediction error: {e}")
        raise HTTPException(
            status_code=500,
            detail="Internal error while processing the text"
        ) from e
280
+
281
+ # Static Files
282
+ app.mount("/", StaticFiles(directory="static", html=True), name="static")
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi==0.128.0
2
+ pydantic==2.12.5
3
+ torch==2.9.1
4
+ transformers==4.57.3
5
+ uvicorn==0.40.0
6
+ aiofiles
static/index.html ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="utf-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>mdeberta-ner-ontonotes5</title>
7
+ <meta name="description" content="Named Entity Recognition Demo">
8
+ <style>
9
+ :root {
10
+ --c-misc: #f3e5f5;
11
+ --t-misc: #4a148c;
12
+ }
13
+
14
+ body {
15
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
16
+ max-width: 800px;
17
+ margin: 0 auto;
18
+ padding: 40px 20px;
19
+ background-color: #fafafa;
20
+ color: #111;
21
+ line-height: 1.6;
22
+ }
23
+
24
+ .container {
25
+ background: white;
26
+ padding: 40px;
27
+ border-radius: 12px;
28
+ box-shadow: 0 2px 8px rgba(0,0,0,0.04);
29
+ border: 1px solid #eaeaea;
30
+ }
31
+
32
+ header {
33
+ margin-bottom: 30px;
34
+ }
35
+
36
+ h1 {
37
+ font-size: 24px;
38
+ font-weight: 600;
39
+ margin: 0 0 8px 0;
40
+ letter-spacing: -0.02em;
41
+ }
42
+
43
+ .subtitle {
44
+ color: #666;
45
+ font-size: 14px;
46
+ }
47
+
48
+ /* Input Section */
49
+ .input-group {
50
+ margin-bottom: 20px;
51
+ position: relative;
52
+ }
53
+
54
+ label {
55
+ display: block;
56
+ font-size: 12px;
57
+ font-weight: 600;
58
+ text-transform: uppercase;
59
+ color: #888;
60
+ margin-bottom: 8px;
61
+ letter-spacing: 0.05em;
62
+ }
63
+
64
+ textarea {
65
+ width: 100%;
66
+ min-height: 120px;
67
+ padding: 16px;
68
+ border: 1px solid #ddd;
69
+ border-radius: 8px;
70
+ font-size: 16px;
71
+ line-height: 1.6;
72
+ resize: vertical;
73
+ background-color: #fff;
74
+ box-sizing: border-box;
75
+ font-family: inherit;
76
+ outline: none;
77
+ transition: border-color 0.2s, box-shadow 0.2s;
78
+ color: #333;
79
+ }
80
+
81
+ textarea:focus {
82
+ border-color: #000;
83
+ box-shadow: 0 0 0 2px rgba(0,0,0,0.05);
84
+ }
85
+
86
+ /* Button */
87
+ button.main-btn {
88
+ background-color: #1a1a1a;
89
+ color: white;
90
+ border: none;
91
+ padding: 14px 32px;
92
+ border-radius: 8px;
93
+ font-family: inherit;
94
+ font-size: 15px;
95
+ font-weight: 500;
96
+ cursor: pointer;
97
+ display: block;
98
+ width: 100%;
99
+ transition: background-color 0.2s, transform 0.1s;
100
+ }
101
+
102
+ button.main-btn:hover {
103
+ background-color: #333;
104
+ }
105
+
106
+ button.main-btn:active {
107
+ transform: scale(0.99);
108
+ }
109
+
110
+ button.main-btn:disabled {
111
+ background-color: #ccc;
112
+ cursor: not-allowed;
113
+ transform: none;
114
+ }
115
+
116
+ /* Output Section */
117
+ #output-wrapper {
118
+ margin-top: 30px;
119
+ display: none;
120
+ animation: fadeIn 0.3s ease-out;
121
+ }
122
+
123
+ .result-box {
124
+ padding: 20px;
125
+ border: 1px solid #eee;
126
+ background-color: #fcfcfc;
127
+ border-radius: 8px;
128
+ font-size: 16px;
129
+ line-height: 1.8;
130
+ white-space: pre-wrap;
131
+ }
132
+
133
+ /* Entity Styling */
134
+ .entity {
135
+ padding: 2px 6px;
136
+ border-radius: 4px;
137
+ font-weight: 500;
138
+ cursor: help;
139
+ position: relative;
140
+ transition: background-color 0.2s;
141
+ box-decoration-break: clone;
142
+ -webkit-box-decoration-break: clone;
143
+ }
144
+
145
+ /* Tooltip */
146
+ .entity::after {
147
+ content: attr(data-label) " " attr(data-score);
148
+ position: absolute;
149
+ bottom: 100%;
150
+ left: 50%;
151
+ transform: translateX(-50%) translateY(-4px);
152
+ background: #1a1a1a;
153
+ color: #fff;
154
+ padding: 4px 8px;
155
+ border-radius: 4px;
156
+ font-size: 11px;
157
+ white-space: nowrap;
158
+ opacity: 0;
159
+ pointer-events: none;
160
+ transition: opacity 0.2s, transform 0.2s;
161
+ z-index: 10;
162
+ font-weight: 400;
163
+ }
164
+
165
+ .entity:hover::after {
166
+ opacity: 1;
167
+ transform: translateX(-50%) translateY(-8px);
168
+ }
169
+
170
+ .type-DEFAULT { background: var(--c-misc); color: var(--t-misc); }
171
+
172
+ @keyframes fadeIn {
173
+ from { opacity: 0; transform: translateY(10px); }
174
+ to { opacity: 1; transform: translateY(0); }
175
+ }
176
+
177
+ .error-msg {
178
+ color: #d32f2f;
179
+ background: #ffebee;
180
+ padding: 12px;
181
+ border-radius: 6px;
182
+ margin-top: 20px;
183
+ font-size: 14px;
184
+ display: none;
185
+ }
186
+
187
+ @media (max-width: 600px) {
188
+ body { padding: 20px 10px; }
189
+ .container { padding: 24px; }
190
+ }
191
+ </style>
192
+ </head>
193
+ <body>
194
+
195
+ <div class="container">
196
+ <header>
197
+ <h1>mdeberta-ner-ontonotes5</h1>
198
+ <div class="subtitle">Named Entity Recognition Demo</div>
199
+ </header>
200
+
201
+ <div class="input-group">
202
+ <label for="inputText">Input Text</label>
203
+ <textarea id="inputText" placeholder="Enter text to analyze...">Apple Inc. is looking at buying a U.K. startup for $1 billion in London next week.</textarea>
204
+ </div>
205
+
206
+ <button id="analyzeBtn" class="main-btn" onclick="analyze()">Analyze Text</button>
207
+
208
+ <div id="errorBox" class="error-msg"></div>
209
+
210
+ <div id="output-wrapper">
211
+ <label>Result</label>
212
+ <div id="resultBox" class="result-box"></div>
213
+ </div>
214
+ </div>
215
+
216
+ <script>
217
/**
 * Send the textarea content to POST /predict and render the entities.
 * Disables the button while the request is running and shows failures in
 * the error box.
 */
async function analyze() {
    const input = document.getElementById('inputText');
    const btn = document.getElementById('analyzeBtn');
    const outputWrapper = document.getElementById('output-wrapper');
    const errorBox = document.getElementById('errorBox');

    const text = input.value.trim();
    if (!text) return;

    // Reset UI
    btn.disabled = true;
    btn.textContent = "Processing...";
    errorBox.style.display = 'none';
    outputWrapper.style.display = 'none';

    try {
        const response = await fetch('/predict', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ text: text })
        });

        if (!response.ok) {
            // statusText is frequently empty, which would leave the user with
            // a bare "Server Error:". Prefer the "detail" string the server
            // puts in its JSON error body, then fall back to the status code.
            let reason = response.statusText;
            try {
                const payload = await response.json();
                if (payload && typeof payload.detail === 'string') {
                    reason = payload.detail;
                }
            } catch (parseError) {
                // Error body was not JSON; keep the fallback below.
            }
            throw new Error(`Server Error: ${reason || response.status}`);
        }

        const data = await response.json();
        // Render against the same (trimmed) text that was sent, because the
        // returned offsets refer to it.
        renderResult(text, data.entities);

        // Show result
        outputWrapper.style.display = 'block';

    } catch (err) {
        errorBox.textContent = err.message;
        errorBox.style.display = 'block';
    } finally {
        btn.disabled = false;
        btn.textContent = "Analyze Text";
    }
}
258
+
259
/**
 * Render the analyzed text into #resultBox, wrapping every entity in a
 * highlighted <span> that carries its label and score for the tooltip.
 *
 * @param {string} originalText  Text that was sent to the server.
 * @param {Array<{entity_group: string, score: number, start: number, end: number}>} entities
 *        Entities whose start/end are offsets into originalText.
 */
function renderResult(originalText, entities) {
    const resultBox = document.getElementById('resultBox');
    resultBox.innerHTML = '';

    if (!entities || entities.length === 0) {
        resultBox.textContent = originalText;
        return;
    }

    // The server computes offsets by indexing a Python string, i.e. in
    // Unicode code points. JavaScript strings are indexed in UTF-16 code
    // units, so slice on an array of code points to keep highlights aligned
    // when the text contains characters outside the BMP (e.g. emoji).
    const codePoints = Array.from(originalText);
    const sliceText = (from, to) => codePoints.slice(from, to).join('');

    // Walk the entities in text order without mutating the caller's array.
    const ordered = [...entities].sort((a, b) => a.start - b.start);

    let lastIndex = 0;

    ordered.forEach(entity => {
        // An entity that starts before the previous one ended would repeat
        // text that is already rendered; skip it.
        if (entity.start < lastIndex) return;

        // 1. Text before entity
        resultBox.appendChild(document.createTextNode(sliceText(lastIndex, entity.start)));

        // 2. Entity
        const span = document.createElement('span');

        // type-DEFAULT is always added so the entity is styled even when no
        // CSS rule exists for its specific type class.
        const type = entity.entity_group || 'DEFAULT';
        span.className = `entity type-${type} type-DEFAULT`;

        span.textContent = sliceText(entity.start, entity.end);

        // Tooltip data
        span.setAttribute('data-label', type);
        span.setAttribute('data-score', Math.round(entity.score * 100) + '%');

        resultBox.appendChild(span);

        lastIndex = entity.end;
    });

    // 3. Remaining text
    resultBox.appendChild(document.createTextNode(sliceText(lastIndex)));
}
298
+ </script>
299
+
300
+ </body>
301
+ </html>