NOT-OMEGA committed on
Commit
595fb58
·
verified ·
1 Parent(s): b310e7b

Delete root

Browse files
Files changed (4) hide show
  1. root/Dockerfile +0 -32
  2. root/api.py +0 -314
  3. root/classify.py +0 -198
  4. root/processor_bert.py +0 -216
root/Dockerfile DELETED
@@ -1,32 +0,0 @@
1
FROM python:3.11-slim

# Build args
# PORT is a build-time default. ARG values are not present in the container's
# runtime environment, so it is re-exported as ENV for HEALTHCHECK and CMD.
ARG PORT=8000
ENV PORT=${PORT}

# System deps (curl is used by the HEALTHCHECK below)
RUN apt-get update && apt-get install -y --no-install-recommends \
    curl \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Install Python deps first (layer cache)
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy source
COPY . .

# Non-root user
RUN adduser --disabled-password --gecos "" appuser \
    && chown -R appuser:appuser /app
USER appuser

# Health check (shell form: ${PORT} is expanded by the shell at runtime from ENV)
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
    CMD curl -f http://localhost:${PORT}/health || exit 1

EXPOSE ${PORT}

# Production: single worker (CPU-bound inference — scale via replicas, not threads)
# Exec-form JSON arrays do not expand variables, so a shell is used to honour
# PORT; `exec` keeps uvicorn as PID 1 so it receives stop signals directly.
CMD ["sh", "-c", "exec uvicorn api:app --host 0.0.0.0 --port ${PORT} --workers 1 --log-level info"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
root/api.py DELETED
@@ -1,314 +0,0 @@
1
- """
2
- api.py — Async FastAPI Inference Service
3
-
4
- Endpoints:
5
- POST /classify — Single log
6
- POST /classify/batch — Batch of logs (up to 512)
7
- GET /health — Liveness check
8
- GET /ready — Readiness check (model loaded?)
9
- GET /metrics — Request counts, throughput, latency stats
10
-
11
- Features:
12
- - Async request handling (non-blocking)
13
- - Worker pool via asyncio semaphore (bounded concurrency)
14
- - Structured JSON logs with request_id
15
- - Rate limiting (configurable)
16
- - Request ID tracing
17
- - Batch queue aggregation for small requests
18
-
19
- Run:
20
- uvicorn api:app --host 0.0.0.0 --port 8000 --workers 1
21
-
22
- Example:
23
- curl -X POST http://localhost:8000/classify \
24
- -H "Content-Type: application/json" \
25
- -d '{"source": "ModernCRM", "log_message": "User User123 logged in."}'
26
- """
27
- from __future__ import annotations
28
- import asyncio
29
- import logging
30
- import os
31
- import time
32
- import uuid
33
- import statistics
34
- from collections import deque
35
- from contextlib import asynccontextmanager
36
- from typing import Optional
37
-
38
- from fastapi import FastAPI, HTTPException, Request, status
39
- from fastapi.middleware.cors import CORSMiddleware
40
- from fastapi.responses import JSONResponse
41
- from pydantic import BaseModel, Field, field_validator
42
-
43
# ── Logging setup ─────────────────────────────────────────────────────────────
# Root logger emits one JSON-shaped line per record.
# NOTE(review): %(message)s is inserted verbatim, so a message containing a
# double quote would presumably produce invalid JSON — confirm if logs are parsed.
logging.basicConfig(
    level=logging.INFO,
    format='{"time":"%(asctime)s","level":"%(levelname)s","logger":"%(name)s","msg":"%(message)s"}'
)
logger = logging.getLogger("log-classifier-api")

# ── Config ─────────────────────────────────────────────────────────────────────
# The first three values are read from environment variables of the same name,
# falling back to the defaults shown; LOG_MAX_CHARS is a fixed constant.
MAX_BATCH_SIZE = int(os.getenv("MAX_BATCH_SIZE", "512"))
MAX_CONCURRENT = int(os.getenv("MAX_CONCURRENT", "4"))  # concurrency cap
RATE_LIMIT_PER_MIN = int(os.getenv("RATE_LIMIT_PER_MIN", "1000"))
LOG_MAX_CHARS = 2048  # truncate huge logs before classify

# ── Global state ───────────────────────────────────────────────────────────────
# Both are assigned by `lifespan` at startup; until then the semaphore is None.
_semaphore: asyncio.Semaphore = None  # type: ignore
_model_ready: bool = False

# Metrics ring buffer (last 1000 requests)
_latencies_ms: deque = deque(maxlen=1000)
_request_count = 0
_error_count = 0
_start_time = time.time()  # captured at import; used for uptime calculations

# Rate limiter (simple sliding window per process)
# Holds request timestamps; bounded to RATE_LIMIT_PER_MIN entries.
_rate_window: deque = deque(maxlen=RATE_LIMIT_PER_MIN)
68
-
69
-
70
- # ── Lifespan: load models on startup ──────────────────────────────────────────
71
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: prepare shared state, then load models.

    On startup it creates the concurrency semaphore and runs the blocking
    model loader in the default executor so the event loop stays responsive.
    A load failure is logged (with traceback) but does not abort startup;
    ``_model_ready`` then stays False and ``/ready`` keeps returning 503.
    """
    global _semaphore, _model_ready

    logger.info("Starting up — loading models…")
    _semaphore = asyncio.Semaphore(MAX_CONCURRENT)

    # Inside a coroutine the running loop is the one to use.
    loop = asyncio.get_running_loop()
    try:
        await loop.run_in_executor(None, _load_models_blocking)
    except Exception as e:
        # Service starts anyway; readiness probe reports the failure.
        logger.exception(f"❌ Model load failed: {e}")
    else:
        _model_ready = True
        logger.info("✅ Models loaded — API ready")

    yield

    logger.info("Shutting down")
91
-
92
-
93
def _load_models_blocking():
    """Load BERT + classifier (blocks — run in executor).

    Importing ``processor_bert`` alone does not load anything, because that
    module loads its models lazily on first use. Calling ``classify_batch``
    with an empty list triggers the load and returns immediately, so any
    load error surfaces here rather than on the first request.
    """
    from processor_bert import classify_batch

    classify_batch([])  # warm-up: forces the lazy model load
    logger.info("BERT model loaded")
97
-
98
-
99
- # ── App factory ────────────────────────────────────────────────────────────────
100
# The ASGI application; `lifespan` runs the startup/shutdown logic.
app = FastAPI(
    title="Log Classification API",
    description="3-tier hybrid pipeline: Regex → BERT → LLM",
    version="3.0.0",
    lifespan=lifespan,
)

# CORS is fully open: any origin, method and header is allowed.
# NOTE(review): confirm a wildcard origin is intended for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
113
-
114
-
115
- # ── Request / Response schemas ─────────────────────────────────────────────────
116
# Request body for a single log line to classify.
class LogRequest(BaseModel):
    # Name of the system that produced the log.
    source: str = Field(..., example="ModernCRM")
    # Raw log text; must be non-empty.
    log_message: str = Field(..., example="User User123 logged in.", min_length=1)

    # Silently cuts the message to LOG_MAX_CHARS characters instead of rejecting it.
    @field_validator("log_message")
    @classmethod
    def truncate_long_logs(cls, v: str) -> str:
        return v[:LOG_MAX_CHARS]
124
-
125
-
126
# Classification result for one log line.
class LogResponse(BaseModel):
    request_id: str  # request identifier assigned by the handler
    label: str  # predicted label
    tier: str  # pipeline tier that produced the label
    confidence: Optional[float]  # may be None when the tier reports no score
    latency_ms: float
    cached: bool = False  # the handlers in this file never set this to True
133
-
134
-
135
# Request body for batch classification; at most MAX_BATCH_SIZE entries.
# No minimum length is declared, so an empty list passes validation.
class BatchRequest(BaseModel):
    logs: list[LogRequest] = Field(..., max_length=MAX_BATCH_SIZE)
137
-
138
-
139
# Response body for batch classification.
class BatchResponse(BaseModel):
    request_id: str
    total: int  # number of logs in the batch
    elapsed_ms: float  # wall time for the whole batch
    throughput: float  # logs per second
    results: list[LogResponse]  # one entry per input log
145
-
146
-
147
# Response body of the liveness endpoint.
class HealthResponse(BaseModel):
    status: str
    uptime_s: float  # seconds since the module was imported
150
-
151
-
152
# Response body of the metrics endpoint.
class MetricsResponse(BaseModel):
    total_requests: int
    total_errors: int
    uptime_s: float
    requests_per_min: float
    # Percentiles are None while no latency samples have been recorded.
    latency_p50_ms: Optional[float]
    latency_p95_ms: Optional[float]
    latency_p99_ms: Optional[float]
160
-
161
-
162
- # ── Rate limiter ───────────────────────────────────────────────────────────────
163
def _check_rate_limit() -> None:
    """Enforce the per-process sliding-window rate limit.

    Keeps the timestamps of accepted requests from the last 60 seconds in
    ``_rate_window`` and rejects the call once ``RATE_LIMIT_PER_MIN`` of them
    are present.

    The check happens *before* appending: ``_rate_window`` is bounded to
    ``RATE_LIMIT_PER_MIN`` entries, so its length can never exceed the limit
    and an "append, then test ``>``" check would never fire.

    Raises:
        HTTPException: 429 when the limit for the current window is reached.
    """
    now = time.time()
    # Timestamps are appended in order, so expired ones are on the left.
    while _rate_window and now - _rate_window[0] >= 60:
        _rate_window.popleft()
    if len(_rate_window) >= RATE_LIMIT_PER_MIN:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=f"Rate limit exceeded: {RATE_LIMIT_PER_MIN} req/min",
        )
    _rate_window.append(now)
173
-
174
-
175
- # ── Middleware: request logging ────────────────────────────────────────────────
176
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """HTTP middleware: tag each request with an ID and log its outcome.

    The ID comes from the incoming ``X-Request-ID`` header when present,
    otherwise a short random one is generated. It is stored on
    ``request.state`` and echoed back in the response header.
    """
    fallback_id = str(uuid.uuid4())[:8]
    rid = request.headers.get("X-Request-ID", fallback_id)
    request.state.request_id = rid

    started = time.perf_counter()
    response = await call_next(request)
    elapsed = (time.perf_counter() - started) * 1000

    summary = (
        f"method={request.method} path={request.url.path} "
        f"status={response.status_code} latency={elapsed:.1f}ms rid={rid}"
    )
    logger.info(summary)

    response.headers["X-Request-ID"] = rid
    return response
189
-
190
-
191
- # ── Health & readiness ─────────────────────────────────────────────────────────
192
@app.get("/health", response_model=HealthResponse, tags=["ops"])
async def health():
    """Liveness probe: always reports ``ok`` plus the process uptime in seconds."""
    uptime_s = round(time.time() - _start_time, 1)
    return {"status": "ok", "uptime_s": uptime_s}
195
-
196
-
197
@app.get("/ready", tags=["ops"])
async def ready():
    """Readiness probe: 503 until the startup model load has succeeded."""
    if _model_ready:
        return {"status": "ready"}
    raise HTTPException(status_code=503, detail="Models not yet loaded")
202
-
203
-
204
- # ── Metrics ────────────────────────────────────────────────────────────────────
205
@app.get("/metrics", response_model=MetricsResponse, tags=["ops"])
async def metrics():
    """Report request counters, uptime and latency percentiles.

    Percentiles are taken from the bounded latency buffer and are ``None``
    while it is empty. The request rate divides by at least one minute, so
    it is not inflated right after startup.
    """
    uptime = time.time() - _start_time
    ordered = sorted(_latencies_ms)
    sample_count = len(ordered)

    def percentile(fraction):
        # Nearest-rank style lookup, clamped to the last sample.
        if not sample_count:
            return None
        index = min(int(sample_count * fraction), sample_count - 1)
        return round(ordered[index], 2)

    minutes = max(uptime / 60, 1)
    return {
        "total_requests": _request_count,
        "total_errors": _error_count,
        "uptime_s": round(uptime, 1),
        "requests_per_min": round(_request_count / minutes, 1),
        "latency_p50_ms": percentile(0.50),
        "latency_p95_ms": percentile(0.95),
        "latency_p99_ms": percentile(0.99),
    }
223
-
224
-
225
- # ── Classify single ────────────────────────────────────────────────────────────
226
@app.post("/classify", response_model=LogResponse, tags=["inference"])
async def classify_single(req: LogRequest, request: Request):
    """Classify one log line.

    Applies the rate limit, then runs the blocking classifier in the default
    executor while holding the concurrency semaphore. The measured latency
    covers only the classification call, not time spent waiting for the
    semaphore.

    Raises:
        HTTPException: 429 from the rate limiter, or 500 if classification
            raises (the original exception is chained and logged with its
            traceback).
    """
    global _request_count, _error_count
    _check_rate_limit()
    _request_count += 1
    rid = getattr(request.state, "request_id", str(uuid.uuid4())[:8])

    async with _semaphore:
        # Inside a coroutine the running loop is the one to use.
        loop = asyncio.get_running_loop()
        t0 = time.perf_counter()
        try:
            result = await loop.run_in_executor(
                None, _classify_blocking, req.source, req.log_message
            )
        except Exception as e:
            _error_count += 1
            logger.exception(f"rid={rid} classify error: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

        latency = (time.perf_counter() - t0) * 1000
        _latencies_ms.append(latency)

    return LogResponse(
        request_id=rid,
        label=result["label"],
        tier=result["tier"],
        confidence=result.get("confidence"),
        latency_ms=round(latency, 2),
    )
255
-
256
-
257
def _classify_blocking(source: str, log_message: str) -> dict:
    """Synchronous single-log classification; intended to run in an executor.

    The pipeline module is imported on first call rather than at module import.
    """
    import classify as pipeline

    return pipeline.classify_log(source, log_message)
260
-
261
-
262
- # ── Classify batch ─────────────────────────────────────────────────────────────
263
@app.post("/classify/batch", response_model=BatchResponse, tags=["inference"])
async def classify_batch_endpoint(req: BatchRequest, request: Request):
    """Classify a batch of log lines in one call.

    The whole batch counts as one request for rate limiting and metrics.
    Per-log latency is reported as the batch wall time divided evenly across
    the logs. An empty batch returns an empty result immediately instead of
    dividing by zero.

    Raises:
        HTTPException: 429 from the rate limiter, or 500 if classification
            raises (the original exception is chained and logged with its
            traceback).
    """
    global _request_count, _error_count
    _check_rate_limit()
    _request_count += 1
    rid = getattr(request.state, "request_id", str(uuid.uuid4())[:8])

    log_pairs = [(r.source, r.log_message) for r in req.logs]
    total = len(log_pairs)

    if total == 0:
        # Nothing to classify; avoids the per-log division below.
        return BatchResponse(
            request_id=rid, total=0, elapsed_ms=0.0, throughput=0.0, results=[]
        )

    async with _semaphore:
        # Inside a coroutine the running loop is the one to use.
        loop = asyncio.get_running_loop()
        t0 = time.perf_counter()
        try:
            results = await loop.run_in_executor(
                None, _classify_batch_blocking, log_pairs
            )
        except Exception as e:
            _error_count += 1
            logger.exception(f"rid={rid} batch error: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

        elapsed_ms = (time.perf_counter() - t0) * 1000
        # Guard against a zero reading from a coarse timer.
        throughput = round(total / (elapsed_ms / 1000), 1) if elapsed_ms > 0 else 0.0
        per_log_ms = elapsed_ms / total
        _latencies_ms.extend([per_log_ms] * total)

    per_log_rounded = round(per_log_ms, 2)
    return BatchResponse(
        request_id=rid,
        total=total,
        elapsed_ms=round(elapsed_ms, 2),
        throughput=throughput,
        results=[
            LogResponse(
                request_id=rid,
                label=r["label"],
                tier=r["tier"],
                confidence=r.get("confidence"),
                latency_ms=per_log_rounded,
            )
            for r in results
        ],
    )
304
-
305
-
306
def _classify_batch_blocking(log_pairs: list[tuple[str, str]]) -> list[dict]:
    """Synchronous batch classification; intended to run in an executor.

    The pipeline module is imported on first call rather than at module import.
    """
    import classify as pipeline

    return pipeline.classify_logs(log_pairs)
309
-
310
-
311
- # ── Dev runner ──────────────────────────────────────────────────────────────────
312
# Runs the app directly with uvicorn when this file is executed as a script
# (single worker, no auto-reload, port 8000 on all interfaces).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run("api:app", host="0.0.0.0", port=8000, reload=False, workers=1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
root/classify.py DELETED
@@ -1,198 +0,0 @@
1
- """
2
- classify.py — 3-Tier Hybrid Pipeline (V3 — Latency-Tracked)
3
-
4
- Architecture:
5
- LegacyCRM → LLM directly
6
- Others → Regex → BERT (batch) → LLM fallback
7
-
8
- Changes in V3:
9
- - Tier-wise latency tracking (regex_ms, bert_ms, llm_ms)
10
- - Pipeline summary with p50/p95 per tier
11
- - Defensive: LLM timeout + retry baked in via processor_llm
12
- - classify_logs returns richer result dict
13
- """
14
- from __future__ import annotations
15
- import time
16
- import statistics
17
- import pandas as pd
18
- from processor_regex import classify_with_regex
19
- from processor_bert import classify_batch as bert_batch
20
- from processor_llm import classify_with_llm
21
-
22
# Logs whose source equals this value skip Regex/BERT and go straight to the LLM.
LEGACY_SOURCE = "LegacyCRM"
23
-
24
-
25
- # ── Result type ─────────────────────────────────────────────────────────────
26
- def _make_result(label: str, tier: str, confidence, latency_ms: float) -> dict:
27
- return {
28
- "label": label,
29
- "tier": tier,
30
- "confidence": confidence,
31
- "latency_ms": round(latency_ms, 3),
32
- }
33
-
34
-
35
- # ── Single log (backward-compatible) ────────────────────────────────────────
36
def classify_log(source: str, log_msg: str) -> dict:
    """Classify a single log by delegating to the batch pipeline.

    Returns the result dict (label, tier, confidence, latency_ms).
    """
    single_batch = [(source, log_msg)]
    return classify_logs(single_batch)[0]
40
-
41
-
42
- # ── Batch pipeline (main entry point) ───────────────────────────────────────
43
def classify_logs(logs: list[tuple[str, str]]) -> list[dict]:
    """
    Batch classify with 3-tier routing + per-result latency.

    Args:
        logs: ``(source, log_message)`` pairs.

    Returns:
        One dict per input, in input order:
        ``{ label, tier, confidence, latency_ms }``

    Tier routing:
        LegacyCRM source → LLM directly
        Regex match      → done
        Remainder        → BERT batch → LLM if BERT returns "Unclassified"

    Latency semantics: Regex and LLM results carry the time of their own
    call; BERT results carry the batch time divided evenly across the batch.
    A log that falls back from BERT to the LLM reports only the LLM time.
    """
    results: list = [None] * len(logs)

    # ── Step 1: Route to groups ─────────────────────────────────────────────
    llm_indices: list[int] = []
    bert_indices: list[int] = []

    for i, (source, log_msg) in enumerate(logs):
        if source == LEGACY_SOURCE:
            llm_indices.append(i)
            continue
        t0 = time.perf_counter()
        label = classify_with_regex(log_msg)
        t1 = time.perf_counter()
        if label:
            results[i] = _make_result(label, "Regex", 1.0, (t1 - t0) * 1000)
        else:
            bert_indices.append(i)

    # ── Step 2: BERT batch ──────────────────────────────────────────────────
    if bert_indices:
        bert_msgs = [logs[i][1] for i in bert_indices]

        t_bert_start = time.perf_counter()
        bert_results = bert_batch(bert_msgs)
        t_bert_end = time.perf_counter()

        bert_ms_per_log = (t_bert_end - t_bert_start) * 1000 / len(bert_msgs)

        for idx, (label, conf) in zip(bert_indices, bert_results):
            if label != "Unclassified":
                results[idx] = _make_result(label, "BERT", conf, bert_ms_per_log)
            else:
                llm_indices.append(idx)

    # ── Step 3: LLM (LegacyCRM + BERT fallback) ────────────────────────────
    for i in llm_indices:
        source, log_msg = logs[i]
        t0 = time.perf_counter()
        label = classify_with_llm(log_msg)
        t1 = time.perf_counter()
        tier = "LLM" if source == LEGACY_SOURCE else "LLM (fallback)"
        results[i] = _make_result(label, tier, None, (t1 - t0) * 1000)

    return results
103
-
104
-
105
- # ── Pipeline summary ─────────────────────────────────────────────────────────
106
def pipeline_summary(results: list[dict]) -> dict:
    """
    Summarise a list of classification results.

    Groups latencies by tier and counts labels, then reports for each tier
    its count, share of the total (percent), p50/p95/p99 and mean latency.

    Returns a dict with keys ``total``, ``tier_stats`` and ``label_counts``.
    """
    latencies_by_tier: dict[str, list[float]] = {}
    label_counts: dict[str, int] = {}

    for entry in results:
        tier_name = entry["tier"]
        if tier_name not in latencies_by_tier:
            latencies_by_tier[tier_name] = []
        latencies_by_tier[tier_name].append(entry["latency_ms"])

        label = entry["label"]
        if label in label_counts:
            label_counts[label] += 1
        else:
            label_counts[label] = 1

    total = len(results)

    def summarise(latencies: list[float]) -> dict:
        # Percentiles use a clamped index into the sorted samples.
        ordered = sorted(latencies)
        count = len(ordered)
        last = count - 1
        return {
            "count": count,
            "pct": round(count / total * 100, 1),
            "p50_ms": round(statistics.median(ordered), 2),
            "p95_ms": round(ordered[min(int(count * 0.95), last)], 2),
            "p99_ms": round(ordered[min(int(count * 0.99), last)], 2),
            "mean_ms": round(statistics.mean(ordered), 2),
        }

    tier_stats = {
        tier_name: summarise(latencies)
        for tier_name, latencies in latencies_by_tier.items()
    }

    return {
        "total": total,
        "tier_stats": tier_stats,
        "label_counts": label_counts,
    }
138
-
139
-
140
- # ── CSV batch classify ───────────────────────────────────────────────────────
141
def classify_csv(input_path: str, output_path: str = "output.csv") -> tuple[str, pd.DataFrame]:
    """
    Classify every row of a CSV file and write the annotated copy.

    Required columns: 'source', 'log_message'
    Output: adds 'predicted_label', 'tier_used', 'latency_ms', 'confidence'
    (confidence is a percentage string, or "N/A" when absent).

    Returns the output path and the annotated DataFrame.
    Raises ValueError when a required column is missing.
    """
    df = pd.read_csv(input_path)

    required = {"source", "log_message"}
    present = set(df.columns)
    if not required <= present:
        raise ValueError(f"CSV mein ye columns chahiye: {required}. Mila: {set(df.columns)}")

    results = classify_logs(list(zip(df["source"], df["log_message"])))

    labels, tiers, latencies, confidences = [], [], [], []
    for res in results:
        labels.append(res["label"])
        tiers.append(res["tier"])
        latencies.append(res["latency_ms"])
        conf = res["confidence"]
        confidences.append("N/A" if conf is None else f"{conf:.1%}")

    # Column order matters for the written file.
    df["predicted_label"] = labels
    df["tier_used"] = tiers
    df["latency_ms"] = latencies
    df["confidence"] = confidences

    df.to_csv(output_path, index=False)
    return output_path, df
165
-
166
-
167
# Aliases
# `classify` is another name for the batch entry point `classify_logs`.
classify = classify_logs
169
-
170
-
171
- # ── Self-test ────────────────────────────────────────────────────────────────
172
# Self-test: classify a handful of sample logs and print a per-log table,
# the per-tier latency summary and the label distribution.
if __name__ == "__main__":
    sample = [
        ("ModernCRM", "IP 192.168.133.114 blocked due to potential attack"),
        ("BillingSystem", "User User12345 logged in."),
        ("AnalyticsEngine", "File data_6957.csv uploaded successfully by user User265."),
        ("ModernHR", "GET /v2/servers/detail HTTP/1.1 status: 200 len: 1583 time: 0.19"),
        ("ModernHR", "Admin access escalation detected for user 9429"),
        ("LegacyCRM", "Case escalation for ticket ID 7324 failed because the assigned support agent is no longer active."),
        ("LegacyCRM", "The 'ReportGenerator' module will be retired in version 4.0."),
    ]

    print(f'{"Source":<20} {"Tier":<18} {"Conf":>6} {"Lat(ms)":>8} {"Label":<25} Log')
    print("─" * 115)
    results = classify_logs(sample)
    for (source, log), r in zip(sample, results):
        # Test against None (as classify_csv does) so a 0.0 confidence is
        # shown as a percentage rather than "N/A".
        conf = f"{r['confidence']:.0%}" if r["confidence"] is not None else " N/A"
        print(f'{source:<20} {r["tier"]:<18} {conf:>6} {r["latency_ms"]:>8.1f} {r["label"]:<25} {log[:40]}')

    summary = pipeline_summary(results)
    print("\n📊 Pipeline Summary:")
    for tier, stats in summary["tier_stats"].items():
        print(f" {tier}: {stats['count']} logs ({stats['pct']}%) | "
              f"p50={stats['p50_ms']}ms p95={stats['p95_ms']}ms p99={stats['p99_ms']}ms")

    print("\n🏷️ Label distribution:")
    for label, count in sorted(summary["label_counts"].items(), key=lambda x: -x[1]):
        print(f" • {label}: {count}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
root/processor_bert.py DELETED
@@ -1,216 +0,0 @@
1
- """
2
- processor_bert_fast.py — ONNX Runtime powered BERT classifier
3
- Speed: 82 logs/s → 2000+ logs/s
4
-
5
- How it works:
6
- 1. ONNX Runtime: 3-5x faster than plain PyTorch
7
- 2. Batch processing: 64 logs are processed together
8
- 3. Pre-allocated buffers: no wasted memory
9
- """
10
- from __future__ import annotations
11
- import os
12
- import numpy as np
13
- import joblib
14
-
15
# ── Backend selection and lazily-loaded model state ─────────
# All of these are filled in by _load_models() on first use.
_USE_ONNX = False  # True once the ONNX session and tokenizer are loaded
_embedding_model = None  # SentenceTransformer used when ONNX is not active
_classifier = None  # joblib-loaded classifier; also marks "models are loaded"
_ort_session = None  # onnxruntime InferenceSession
_ort_tokenizer = None  # tokenizer loaded from ONNX_DIR

# Model artefacts live in a `models` directory next to this file.
MODEL_PATH = os.path.join(os.path.dirname(__file__), 'models', 'log_classifier.joblib')
ONNX_DIR = os.path.join(os.path.dirname(__file__), 'models', 'onnx')
CONFIDENCE_THRESHOLD = 0.30  # predictions below this become 'Unclassified'
DEFAULT_BATCH = 64  # number of logs embedded per batch
26
-
27
-
28
def _load_models():
    """Lazily load the classifier and an embedding backend (once).

    The ONNX backend is used when ``models/onnx/model.onnx`` exists and
    loads cleanly; otherwise the SentenceTransformer (PyTorch) backend is
    used. Everything is loaded into locals first and published to the module
    globals only after all of it succeeded, with ``_classifier`` (the
    "already loaded" sentinel) assigned last. A failure part-way through
    therefore leaves the module unloaded, so the next call retries instead of
    running with a missing embedding backend.

    Raises:
        FileNotFoundError: if the classifier file is missing.
    """
    global _USE_ONNX, _embedding_model, _classifier, _ort_session, _ort_tokenizer

    if _classifier is not None:
        return  # Already loaded

    # ── Load the classifier ────────────────────────────────
    if not os.path.exists(MODEL_PATH):
        raise FileNotFoundError(
            f'Model nahi mila: {MODEL_PATH}\n'
            'Pehle Colab notebook run karo aur model download karo.'
        )
    # NOTE(security): joblib.load unpickles the file and can execute arbitrary
    # code; MODEL_PATH must only ever point at a trusted artefact.
    classifier = joblib.load(MODEL_PATH)

    use_onnx = False
    session = None
    tokenizer = None
    embedding_model = None

    # ── Try ONNX (fast), fall back to PyTorch ──────────────
    onnx_model_file = os.path.join(ONNX_DIR, 'model.onnx')

    if os.path.exists(onnx_model_file):
        try:
            import onnxruntime as ort
            from transformers import AutoTokenizer

            # CPU optimized session options
            sess_opts = ort.SessionOptions()
            sess_opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
            # os.cpu_count() may return None; fall back to a single thread.
            sess_opts.intra_op_num_threads = os.cpu_count() or 1
            sess_opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL

            session = ort.InferenceSession(
                onnx_model_file,
                sess_options=sess_opts,
                providers=['CPUExecutionProvider']
            )
            tokenizer = AutoTokenizer.from_pretrained(ONNX_DIR)
            use_onnx = True
            print('[BERT] ✅ ONNX Runtime loaded — FAST MODE')

        except Exception as e:
            # Deliberate best-effort: any ONNX problem falls back to PyTorch.
            print(f'[BERT] ONNX load failed ({e}), fallback to PyTorch')
            use_onnx = False
            session = None
            tokenizer = None

    if not use_onnx:
        from sentence_transformers import SentenceTransformer
        embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        print('[BERT] ⚠️ PyTorch mode (install ONNX for 3-5x speedup)')

    # Publish only after everything loaded; the sentinel goes last.
    _USE_ONNX = use_onnx
    _ort_session = session
    _ort_tokenizer = tokenizer
    _embedding_model = embedding_model
    _classifier = classifier
74
-
75
-
76
def _embed_onnx(texts: list[str]) -> np.ndarray:
    """Embed texts with the ONNX session and return L2-normalised vectors.

    Tokenises (padded, truncated to 128 tokens), runs the session, mean-pools
    the first output over tokens using the attention mask, then normalises
    each row. Requires the ONNX session and tokenizer to be loaded already.

    No torch import is needed here: the tokenizer is asked for NumPy arrays.
    """
    inputs = _ort_tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors='np'  # NumPy directly, no PyTorch tensors involved
    )

    # Feed only the inputs the exported graph actually declares.
    ort_inputs = {
        'input_ids': inputs['input_ids'].astype(np.int64),
        'attention_mask': inputs['attention_mask'].astype(np.int64),
    }
    declared_inputs = {i.name for i in _ort_session.get_inputs()}
    if 'token_type_ids' in declared_inputs:
        ort_inputs['token_type_ids'] = inputs.get(
            'token_type_ids', np.zeros_like(inputs['input_ids'])
        ).astype(np.int64)

    outputs = _ort_session.run(None, ort_inputs)
    hidden = outputs[0]  # presumably (batch, seq_len, hidden) — depends on the export

    # Mean pooling (attention mask weighted)
    mask = inputs['attention_mask'][:, :, None].astype(np.float32)
    summed = (hidden * mask).sum(axis=1)
    # Guard against an all-zero mask row; non-zero counts are unchanged.
    counts = np.maximum(mask.sum(axis=1), 1e-9)
    embeddings = summed / counts

    # L2 normalize
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    return embeddings / (norms + 1e-8)
110
-
111
-
112
def _embed_pytorch(texts: list[str]) -> np.ndarray:
    """Embed texts with the SentenceTransformer backend (normalised NumPy output)."""
    encode_options = dict(
        batch_size=DEFAULT_BATCH,
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=False,
    )
    vectors = _embedding_model.encode(texts, **encode_options)
    return vectors
121
-
122
-
123
- # ── PUBLIC API ──────────────────────────────────────────────
124
-
125
def classify_with_bert(log_message: str) -> tuple[str, float]:
    """
    Classify a single log message.

    Returns: (label, confidence)
    """
    _load_models()
    single_result = classify_batch([log_message])
    return single_result[0]
133
-
134
-
135
def classify_batch(log_messages: list[str]) -> list[tuple[str, float]]:
    """
    Classify many log messages at once.

    Messages are embedded and classified in chunks of DEFAULT_BATCH. A
    prediction whose top probability is below CONFIDENCE_THRESHOLD is
    reported as 'Unclassified' (its probability is still returned).

    Returns: list of (label, confidence) tuples, in input order.

    Example:
        results = classify_batch(['log1', 'log2', 'log3'])
        for label, conf in results:
            print(f'{label}: {conf:.1%}')
    """
    _load_models()

    if not log_messages:
        return []

    embed = _embed_onnx if _USE_ONNX else _embed_pytorch
    classified: list[tuple[str, float]] = []

    for start in range(0, len(log_messages), DEFAULT_BATCH):
        chunk = log_messages[start:start + DEFAULT_BATCH]
        vectors = embed(chunk)

        probabilities = _classifier.predict_proba(vectors)
        top_probs = probabilities.max(axis=1)
        predicted = _classifier.predict(vectors)

        classified.extend(
            ('Unclassified' if conf < CONFIDENCE_THRESHOLD else str(label), float(conf))
            for label, conf in zip(predicted, top_probs)
        )

    return classified
174
-
175
-
176
def get_classes() -> list[str]:
    """Return the class labels of the loaded classifier."""
    _load_models()
    return [*_classifier.classes_]
180
-
181
-
182
def is_onnx_mode() -> bool:
    """Report whether the ONNX backend is in use (loads models if needed)."""
    _load_models()
    onnx_active = _USE_ONNX
    return onnx_active
186
-
187
-
188
- # ── TEST ────────────────────────────────────────────────────
189
# Manual smoke test: classify sample logs one at a time, print which backend
# is active, then time one large batch.
if __name__ == '__main__':
    import time

    test_logs = [
        'GET /v2/servers/detail HTTP/1.1 status: 404 len: 1583 time: 0.19',
        'System crashed due to driver errors when restarting the server',
        'Multiple login failures occurred on user 6454 account',
        'Admin access escalation detected for user 9429',
        'CPU usage at 98% for the last 10 minutes on node-7',
        'Backup completed successfully.',
        'User User123 logged in.',
        'Data replication task for shard 14 did not complete',
        'Hey bro chill ya!',  # should be Unclassified
    ]

    print('Single log test:')
    for log in test_logs:
        label, conf = classify_with_bert(log)
        print(f' [{conf:.0%}] {label:25s} | {log[:60]}')

    print(f'\nMode: {"ONNX 🚀" if is_onnx_mode() else "PyTorch"}')

    # Speed test: the sample list repeated 100 times, classified in one call
    big_batch = test_logs * 100
    t0 = time.perf_counter()
    classify_batch(big_batch)
    elapsed = time.perf_counter() - t0
    print(f'\nSpeed: {len(big_batch)/elapsed:.0f} logs/s ({elapsed*1000/len(big_batch):.1f}ms/log)')