AIFinder / inference.py
CompactAI's picture
Upload 8 files
17ef86f verified
raw
history blame
1.64 kB
"""
AIFinder Inference Module
Load the trained model ensemble and predict which AI provider wrote a given text.
"""
from pathlib import Path

import joblib
import numpy as np

from config import MODEL_DIR
class AIFinder:
    """Attribute a piece of text to an AI provider using a model ensemble.

    Three joblib artifacts are loaded from ``model_dir``:

    * ``rf_4provider.joblib`` -- an iterable of fitted models, each exposing
      ``predict_proba``; their probability outputs are averaged.
    * ``pipeline_4provider.joblib`` -- a fitted transformer that turns a list
      of texts into the feature matrix the models consume.
    * ``enc_4provider.joblib`` -- an encoder whose ``classes_`` holds the
      provider labels. NOTE(review): the label order is assumed to match the
      models' probability columns -- TODO confirm against the training code.
    """

    def __init__(self, model_dir=MODEL_DIR):
        """Load the ensemble, feature pipeline and label encoder.

        Args:
            model_dir: Directory (str or path-like) containing the artifacts.

        NOTE(review): ``joblib.load`` unpickles its input, which can execute
        arbitrary code -- only point ``model_dir`` at trusted files.
        """
        base = Path(model_dir)
        self.models = joblib.load(base / "rf_4provider.joblib")
        self.pipeline = joblib.load(base / "pipeline_4provider.joblib")
        self.le = joblib.load(base / "enc_4provider.joblib")

    def _ensemble_proba(self, text):
        """Return the ensemble's averaged class-probability vector for one text.

        The result is 1-D, with one entry per probability column.
        """
        features = self.pipeline.transform([text])
        # Each model returns one probability row per input text (here exactly
        # one); average across models, then drop the single-row axis.
        per_model = [model.predict_proba(features) for model in self.models]
        return np.mean(per_model, axis=0)[0]

    def _labels(self):
        """Return the provider labels as native Python objects."""
        # .tolist() converts numpy scalars (e.g. np.str_) to plain str/int so
        # results print cleanly and are JSON-serialisable; values compare and
        # hash equal to the numpy originals.
        return np.asarray(self.le.classes_).tolist()

    def predict(self, text):
        """Return the most likely provider label for ``text``."""
        proba = self._ensemble_proba(text)
        return self._labels()[int(np.argmax(proba))]

    def predict_proba(self, text):
        """Return a ``{provider: probability}`` dict for ``text``.

        Probabilities are the mean of the ensemble members' outputs, as
        plain Python floats.
        """
        proba = self._ensemble_proba(text)
        return dict(zip(self._labels(), proba.tolist()))

    def predict_with_confidence(self, text):
        """Return ``(provider, confidence)`` for the most likely provider.

        ``confidence`` is that provider's averaged probability. Ties resolve
        to the first label in encoder order, matching ``predict``.
        """
        proba = self.predict_proba(text)
        provider = max(proba, key=proba.get)
        return provider, proba[provider]
def main():
    """Classify a few sample texts and print each predicted provider."""
    finder = AIFinder()
    test_texts = [
        "AI is like a really smart robot helper.",
        "Yes, coding is one of my stronger skills!",
        "A lot, depending on what you need.",
    ]
    preview_len = 50
    for text in test_texts:
        provider, conf = finder.predict_with_confidence(text)
        # Append an ellipsis only when the text was actually cut short.
        if len(text) > preview_len:
            preview = f"{text[:preview_len]}..."
        else:
            preview = text
        print(f"Text: {preview}")
        print(f"Provider: {provider} (confidence: {conf:.2f})")
        print()


if __name__ == "__main__":
    main()