abhinavvvvv commited on
Commit
1e4a1ed
·
1 Parent(s): aa3f563

solved batch problem

Browse files
Files changed (1) hide show
  1. app.py +54 -40
app.py CHANGED
@@ -2,25 +2,17 @@ import pickle
2
  import os
3
  from fastapi import FastAPI
4
  from pydantic import BaseModel
 
5
 
6
- # -----------------------------
7
- # FastAPI initialization
8
- # -----------------------------
9
  app = FastAPI(
10
  title="RTL Log Severity Classifier",
11
- description="Machine learning API that predicts severity of RTL verification logs.",
12
- version="1.0"
13
  )
14
 
15
- # -----------------------------
16
- # Model paths
17
- # -----------------------------
18
  VECTORIZER_PATH = "vectorizer.pkl"
19
  MODEL_PATH = "severity_model.pkl"
20
 
21
- # -----------------------------
22
- # Severity mapping
23
- # -----------------------------
24
  REVERSE_MAP = {
25
  0: "INFO",
26
  1: "WARNING",
@@ -28,57 +20,79 @@ REVERSE_MAP = {
28
  3: "CRITICAL"
29
  }
30
 
31
- # -----------------------------
32
- # Load artifacts safely
33
- # -----------------------------
34
- if not os.path.exists(VECTORIZER_PATH):
35
- raise RuntimeError("vectorizer.pkl not found")
36
-
37
- if not os.path.exists(MODEL_PATH):
38
- raise RuntimeError("severity_model.pkl not found")
39
-
40
  with open(VECTORIZER_PATH, "rb") as f:
41
  vectorizer = pickle.load(f)
42
 
43
  with open(MODEL_PATH, "rb") as f:
44
  model = pickle.load(f)
45
 
46
- # -----------------------------
47
- # Request schema
48
- # -----------------------------
49
- class LogRequest(BaseModel):
50
  module: str
51
  message: str
52
 
53
 
54
- # -----------------------------
55
- # Health check
56
- # -----------------------------
 
 
 
57
  @app.get("/")
58
- def health_check():
59
  return {
60
  "status": "running",
61
  "model": "RTL Severity Classifier",
62
- "classes": ["INFO", "WARNING", "ERROR", "CRITICAL"]
63
  }
64
 
65
 
66
- # -----------------------------
67
- # Prediction endpoint
68
- # -----------------------------
69
  @app.post("/predict")
70
- def predict_severity(request: LogRequest):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- text = request.module + " " + request.message
73
 
74
- vector = vectorizer.transform([text])
75
 
76
- pred = model.predict(vector)[0]
77
 
78
- severity = REVERSE_MAP[pred]
 
 
 
 
79
 
80
  return {
81
- "module": request.module,
82
- "message": request.message,
83
- "predicted_severity": severity
84
  }
 
2
  import os
3
  from fastapi import FastAPI
4
  from pydantic import BaseModel
5
+ from typing import List
6
 
 
 
 
7
  app = FastAPI(
8
  title="RTL Log Severity Classifier",
9
+ description="Batch severity prediction for RTL verification logs",
10
+ version="1.1"
11
  )
12
 
 
 
 
13
  VECTORIZER_PATH = "vectorizer.pkl"
14
  MODEL_PATH = "severity_model.pkl"
15
 
 
 
 
16
  REVERSE_MAP = {
17
  0: "INFO",
18
  1: "WARNING",
 
20
  3: "CRITICAL"
21
  }
22
 
23
+ # Load artifacts
 
 
 
 
 
 
 
 
24
  with open(VECTORIZER_PATH, "rb") as f:
25
  vectorizer = pickle.load(f)
26
 
27
  with open(MODEL_PATH, "rb") as f:
28
  model = pickle.load(f)
29
 
30
+
31
+ # ---------- Request Schemas ----------
32
+
33
+ class LogItem(BaseModel):
34
  module: str
35
  message: str
36
 
37
 
38
+ class BatchRequest(BaseModel):
39
+ logs: List[LogItem]
40
+
41
+
42
+ # ---------- Health ----------
43
+
44
  @app.get("/")
45
+ def health():
46
  return {
47
  "status": "running",
48
  "model": "RTL Severity Classifier",
49
+ "batch_support": True
50
  }
51
 
52
 
53
+ # ---------- Single Prediction ----------
54
+
 
55
  @app.post("/predict")
56
+ def predict(log: LogItem):
57
+
58
+ text = log.module + " " + log.message
59
+
60
+ vec = vectorizer.transform([text])
61
+
62
+ pred = model.predict(vec)[0]
63
+
64
+ return {
65
+ "module": log.module,
66
+ "message": log.message,
67
+ "predicted_severity": REVERSE_MAP[pred]
68
+ }
69
+
70
+
71
+ # ---------- Batch Prediction ----------
72
+
73
+ @app.post("/predict_batch")
74
+ def predict_batch(request: BatchRequest):
75
+
76
+ texts = [
77
+ log.module + " " + log.message
78
+ for log in request.logs
79
+ ]
80
+
81
+ vectors = vectorizer.transform(texts)
82
 
83
+ preds = model.predict(vectors)
84
 
85
+ results = []
86
 
87
+ for i, p in enumerate(preds):
88
 
89
+ results.append({
90
+ "module": request.logs[i].module,
91
+ "message": request.logs[i].message,
92
+ "predicted_severity": REVERSE_MAP[int(p)]
93
+ })
94
 
95
  return {
96
+ "count": len(results),
97
+ "results": results
 
98
  }