# Text_classification / api_symptom_checker.py
# Tantawi's picture
# Upload 13 files
# f2a4578 verified
import argparse
import json
import os
from typing import List, Dict, Any
import numpy as np
import xgboost as xgb
from sklearn.preprocessing import LabelEncoder
def load_artifacts(prefix: str):
    """Load the trained model artifacts stored under a common path prefix.

    Three files are expected next to each other:
    ``<prefix>.json`` (XGBoost model), ``<prefix>.labels.npy`` (label
    encoder classes) and ``<prefix>.features.txt`` (one feature name per
    line; blank lines are skipped).

    Args:
        prefix: Path prefix shared by the three artifact files.

    Returns:
        A ``(model, label_encoder, feature_names)`` tuple.

    Raises:
        FileNotFoundError: If any of the three files does not exist. The
            message lists the expected paths and which ones were not found.
    """
    model_path = f"{prefix}.json"
    labels_path = f"{prefix}.labels.npy"
    features_path = f"{prefix}.features.txt"

    # Report exactly which files are absent so the caller does not have to
    # check all three by hand.
    missing = [
        path
        for path in (model_path, labels_path, features_path)
        if not os.path.exists(path)
    ]
    if missing:
        raise FileNotFoundError(
            f"Missing artifacts. Expected: {model_path}, {labels_path}, {features_path}"
            f" (not found: {', '.join(missing)})"
        )

    model = xgb.XGBClassifier()
    model.load_model(model_path)

    label_encoder = LabelEncoder()
    # SECURITY: allow_pickle=True lets np.load unpickle arbitrary objects,
    # which can execute code. Only load label files from a trusted source.
    classes = np.load(labels_path, allow_pickle=True)
    label_encoder.classes_ = classes

    with open(features_path, "r", encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        feature_names = [name for name in stripped_lines if name]

    return model, label_encoder, feature_names
def build_feature_vector(symptom_names: List[str], selected_symptoms: List[str]) -> np.ndarray:
    """Encode the selected symptoms as a single-row 0/1 feature matrix.

    Names are compared after lower-casing and trimming surrounding
    whitespace. Selected symptoms that match no known name are ignored.

    Args:
        symptom_names: Known feature names; position defines the column.
        selected_symptoms: Symptoms to switch on.

    Returns:
        Float array of shape ``(1, len(symptom_names))`` holding 1.0 in the
        column of every matched symptom and 0.0 elsewhere.
    """
    vector = np.zeros(len(symptom_names), dtype=float)

    # Normalised name -> column. A later duplicate overrides an earlier one.
    column_of = {}
    for column, raw_name in enumerate(symptom_names):
        column_of[raw_name.lower().strip()] = column

    for raw_symptom in selected_symptoms:
        column = column_of.get(raw_symptom.lower().strip())
        if column is not None:
            vector[column] = 1.0

    return vector[np.newaxis, :]
def predict_symptoms_json(symptoms: List[str], model, label_encoder, feature_names: List[str]) -> Dict[str, Any]:
    """Build a JSON-ready dict with the three most probable diseases.

    Args:
        symptoms: Symptom names supplied by the caller.
        model: Object exposing ``predict_proba``.
        label_encoder: Object exposing ``inverse_transform``.
        feature_names: Feature names used to build the input row.

    Returns:
        ``{"error": ...}`` when ``symptoms`` is empty. Otherwise a dict with
        the echoed input, the top prediction, the ranked predictions and a
        ``model_confidence`` bucket: "high" above 0.7, "medium" above 0.4,
        else "low", based on the top prediction's confidence.
    """
    if not symptoms:
        return {"error": "No symptoms provided"}

    feature_row = build_feature_vector(feature_names, symptoms)
    class_probabilities = model.predict_proba(feature_row)[0]

    # argsort is ascending: reverse it and keep the first three entries to
    # get the most probable classes, best first.
    best_first = np.argsort(class_probabilities)[::-1][:3]

    ranked = []
    for position, class_idx in enumerate(best_first, start=1):
        label = label_encoder.inverse_transform([class_idx])[0]
        score = float(class_probabilities[class_idx])
        entry = {
            "rank": position,
            "disease": label,
            "confidence": score,
            "confidence_percent": round(score * 100, 2),
        }
        ranked.append(entry)

    top_score = ranked[0]["confidence"]
    if top_score > 0.7:
        confidence_level = "high"
    elif top_score > 0.4:
        confidence_level = "medium"
    else:
        confidence_level = "low"

    return {
        "input_symptoms": symptoms,
        "primary_diagnosis": ranked[0],
        "top_predictions": ranked,
        "model_confidence": confidence_level,
    }
def predict_symptoms_csv(symptoms: List[str], model, label_encoder, feature_names: List[str]) -> str:
    """Return the three most probable diseases as CSV text.

    Rows are written with the ``csv`` module so a disease label containing
    a comma, quote or newline is quoted instead of corrupting the row.
    Output is unchanged for labels without such characters.

    Args:
        symptoms: Symptom names supplied by the caller.
        model: Object exposing ``predict_proba``.
        label_encoder: Object exposing ``inverse_transform``.
        feature_names: Feature names used to build the input row.

    Returns:
        ``"error,No symptoms provided"`` when ``symptoms`` is empty.
        Otherwise a header line followed by one line per prediction,
        joined by ``"\\n"`` with no trailing newline.
    """
    # Local imports: these modules are needed only by this function.
    import csv
    import io

    if not symptoms:
        return "error,No symptoms provided"

    x = build_feature_vector(feature_names, symptoms)
    proba = model.predict_proba(x)[0]
    # argsort is ascending: take the last three and reverse for best-first.
    top3_idx = np.argsort(proba)[-3:][::-1]

    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    writer.writerow(["rank", "disease", "confidence", "confidence_percent"])
    for rank, idx in enumerate(top3_idx, 1):
        disease_name = label_encoder.inverse_transform([idx])[0]
        confidence = proba[idx]
        writer.writerow(
            [rank, str(disease_name), f"{confidence:.4f}", f"{confidence * 100:.2f}"]
        )

    # The last field is numeric, so this strips only the final terminator.
    return buffer.getvalue().rstrip("\n")
def predict_symptoms_simple(symptoms: List[str], model, label_encoder, feature_names: List[str]) -> str:
    """Return a one-line text summary of the single most probable disease.

    Args:
        symptoms: Symptom names supplied by the caller.
        model: Object exposing ``predict_proba``.
        label_encoder: Object exposing ``inverse_transform``.
        feature_names: Feature names used to build the input row.

    Returns:
        ``"Error: No symptoms provided"`` when ``symptoms`` is empty,
        otherwise ``"Diagnosis: <disease> (Confidence: <pct>%)"`` with the
        percentage shown to one decimal place.
    """
    if not symptoms:
        return "Error: No symptoms provided"

    feature_row = build_feature_vector(feature_names, symptoms)
    class_probabilities = model.predict_proba(feature_row)[0]

    best_idx = np.argmax(class_probabilities)
    best_label = label_encoder.inverse_transform([best_idx])[0]
    percent = class_probabilities[best_idx] * 100

    return f"Diagnosis: {best_label} (Confidence: {percent:.1f}%)"
def main():
    """Command-line entry point: load the artifacts and print predictions.

    Parses ``--symptoms``, ``--format`` and ``--artifacts-prefix``, loads
    the model artifacts and prints the prediction in the requested format.
    Any exception raised while loading or predicting is caught and printed
    instead: as a JSON object for the "json" format, otherwise as a plain
    ``Error: ...`` line.
    """
    cli = argparse.ArgumentParser(description="API-style symptom checker using saved model")
    cli.add_argument("--symptoms", nargs="+", required=True, help="List of symptoms")
    cli.add_argument("--format", choices=["json", "csv", "simple"], default="json", help="Output format")
    cli.add_argument("--artifacts-prefix", default="symptom_checker/symptom_model", help="Path to model artifacts")
    args = cli.parse_args()

    # One renderer per output format; each returns the text to print.
    renderers = {
        "json": lambda *call_args: json.dumps(predict_symptoms_json(*call_args), indent=2),
        "csv": predict_symptoms_csv,
        "simple": predict_symptoms_simple,
    }

    try:
        model, label_encoder, feature_names = load_artifacts(args.artifacts_prefix)
        render = renderers.get(args.format)
        if render is not None:
            print(render(args.symptoms, model, label_encoder, feature_names))
    except Exception as e:
        # Top-level boundary: report the failure in the requested format.
        failure = {"error": str(e), "input_symptoms": args.symptoms}
        if args.format != "json":
            print(f"Error: {e}")
        else:
            print(json.dumps(failure, indent=2))
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()