sumoy47 committed on
Commit
7e3db74
·
verified ·
1 Parent(s): 666b4cd

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +24 -15
main.py CHANGED
@@ -11,7 +11,6 @@ import os
11
 
12
  app = FastAPI(title="MedGuard API")
13
 
14
- # --- CORS CONFIGURATION ---
15
  app.add_middleware(
16
  CORSMiddleware,
17
  allow_origins=["*"],
@@ -20,7 +19,6 @@ app.add_middleware(
20
  allow_headers=["*"],
21
  )
22
 
23
- # --- CONFIGURATION ---
24
  MODEL_PATH = "./model"
25
  DEVICE = "cpu"
26
 
@@ -28,20 +26,31 @@ print(f"🔄 Loading Model from {MODEL_PATH}...")
28
  model = None
29
  tokenizer = None
30
 
 
 
 
 
 
31
  try:
32
  tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
33
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
34
  model.to(DEVICE)
35
  model.eval()
36
- print("✅ Model Loaded Successfully!")
 
 
 
 
 
 
 
 
37
  except Exception as e:
38
  print(f"❌ Error loading local model: {e}")
39
- # Fallback
40
  MODEL_NAME = "csebuetnlp/banglabert"
41
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
42
  model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
43
 
44
- # --- DATA MODELS ---
45
  class QueryRequest(BaseModel):
46
  genre: str = ""
47
  prompt: str = ""
@@ -53,8 +62,6 @@ class PredictionResponse(BaseModel):
53
  probs: dict
54
  explanation: list = None
55
 
56
- LABELS = ["Highly Relevant", "Partially Relevant", "Not Relevant"]
57
-
58
  def predict_proba_lime(texts):
59
  inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=128).to(DEVICE)
60
  with torch.no_grad():
@@ -63,7 +70,7 @@ def predict_proba_lime(texts):
63
 
64
  @app.get("/")
65
  def health_check():
66
- return {"status": "active", "model": "MedGuard v2.1 (Space Concatenation)"}
67
 
68
  @app.post("/predict", response_model=PredictionResponse)
69
  def predict(request: QueryRequest):
@@ -71,15 +78,12 @@ def predict(request: QueryRequest):
71
  raise HTTPException(status_code=503, detail="Model not loaded")
72
 
73
  try:
74
- # --- CRITICAL FIX ---
75
- # We use simple SPACE concatenation to match standard training dataframe practices.
76
- # No [SEP] tokens, just "Genre Prompt Response"
77
  parts = [part for part in [request.genre, request.prompt, request.text] if part]
78
  full_input = " ".join(parts)
79
 
80
- print(f"📥 Analyzing: {full_input[:100]}...")
81
 
82
- # 1. PREDICT
83
  inputs = tokenizer(full_input, return_tensors="pt", truncation=True, max_length=128).to(DEVICE)
84
  with torch.no_grad():
85
  outputs = model(**inputs)
@@ -87,7 +91,12 @@ def predict(request: QueryRequest):
87
 
88
  pred_idx = np.argmax(probs)
89
 
90
- # 2. EXPLAIN (LIME)
 
 
 
 
 
91
  explainer = LimeTextExplainer(
92
  class_names=LABELS,
93
  split_expression=lambda x: x.split()
@@ -103,7 +112,7 @@ def predict(request: QueryRequest):
103
  lime_features = exp.as_list(label=pred_idx)
104
 
105
  return {
106
- "label": LABELS[pred_idx],
107
  "confidence": round(float(probs[pred_idx]) * 100, 2),
108
  "probs": {l: round(float(p), 4) for l, p in zip(LABELS, probs)},
109
  "explanation": lime_features
 
11
 
12
  app = FastAPI(title="MedGuard API")
13
 
 
14
  app.add_middleware(
15
  CORSMiddleware,
16
  allow_origins=["*"],
 
19
  allow_headers=["*"],
20
  )
21
 
 
22
# Location of the bundled fine-tuned checkpoint, and the device inference runs on.
MODEL_PATH = "./model"
DEVICE = "cpu"
24
 
 
26
# Model handles are filled in at startup; they remain None if loading fails,
# which the /predict endpoint checks before serving.
model = None
tokenizer = None

# Class names indexed by model output id. The order MUST follow the training
# label map: {'Not Relevant': 0, 'Partially Relevant': 1, 'Highly Relevant': 2}.
LABELS = ["Not Relevant", "Partially Relevant", "Highly Relevant"]
34
# Load the fine-tuned classifier from the local checkpoint directory.
# On any failure (missing/corrupt files), fall back to the base BanglaBERT
# checkpoint from the Hub so the API can still start.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
    model.to(DEVICE)
    model.eval()

    # Surface the label mapping stored in the checkpoint so it can be
    # compared (manually) against the hard-coded LABELS list.
    if model.config.id2label:
        print(f"ℹ️ Model config labels: {model.config.id2label}")
        # We enforce the manual LABELS list because saved configs sometimes
        # get mangled; visually verify this print matches LABELS.

    print(f"✅ Model Loaded! Label Mapping: {LABELS}")

except Exception as e:
    print(f"❌ Error loading local model: {e}")
    MODEL_NAME = "csebuetnlp/banglabert"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)
    # BUGFIX: the fallback path previously skipped .to(DEVICE) and .eval(),
    # leaving the model in training mode (dropout active -> nondeterministic
    # predictions) unlike the local-load path above.
    model.to(DEVICE)
    model.eval()
53
 
 
54
  class QueryRequest(BaseModel):
55
  genre: str = ""
56
  prompt: str = ""
 
62
  probs: dict
63
  explanation: list = None
64
 
 
 
65
  def predict_proba_lime(texts):
66
  inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=128).to(DEVICE)
67
  with torch.no_grad():
 
70
 
71
@app.get("/")
def health_check():
    """Liveness probe: report service status and the deployed model version."""
    payload = {"status": "active", "model": "MedGuard v2.3 (Fixed Labels)"}
    return payload
74
 
75
  @app.post("/predict", response_model=PredictionResponse)
76
  def predict(request: QueryRequest):
 
78
  raise HTTPException(status_code=503, detail="Model not loaded")
79
 
80
  try:
81
+ # Use simple space concatenation
 
 
82
  parts = [part for part in [request.genre, request.prompt, request.text] if part]
83
  full_input = " ".join(parts)
84
 
85
+ print(f"📥 Analyzing: {full_input[:50]}...")
86
 
 
87
  inputs = tokenizer(full_input, return_tensors="pt", truncation=True, max_length=128).to(DEVICE)
88
  with torch.no_grad():
89
  outputs = model(**inputs)
 
91
 
92
  pred_idx = np.argmax(probs)
93
 
94
+ # Ensure index is valid
95
+ if pred_idx >= len(LABELS):
96
+ label_str = "Unknown"
97
+ else:
98
+ label_str = LABELS[pred_idx]
99
+
100
  explainer = LimeTextExplainer(
101
  class_names=LABELS,
102
  split_expression=lambda x: x.split()
 
112
  lime_features = exp.as_list(label=pred_idx)
113
 
114
  return {
115
+ "label": label_str,
116
  "confidence": round(float(probs[pred_idx]) * 100, 2),
117
  "probs": {l: round(float(p), 4) for l, p in zip(LABELS, probs)},
118
  "explanation": lime_features