ModelSmith-AI / backend /core /strategy_reasoner.py
ACA050's picture
Upload 79 files
a309487 verified
class StrategyReasoner:
    """Rule-based chooser of a model family from dataset traits.

    ``decide`` turns boolean-ish flags in ``dataset_info`` into a list of
    risk labels, picks a model family for the given problem type and
    attaches a confidence value that shrinks as more risks are found.
    The ``explain_*`` methods render a strategy dict as one line of text.
    """

    # (dataset_info key, risk label, confidence penalty).
    # Evaluated in this order, which is also the order of the "risks" list.
    _RISK_RULES = (
        ("small_data", "small_dataset", 0.1),
        ("high_dimensional", "high_dimensionality", 0.1),
        ("imbalance", "class_imbalance", 0.2),
        ("sparse_data", "high_missingness", 0.2),
    )

    # Upper bound applied to the summed penalties before computing
    # confidence. With the rules above the sum tops out at 0.6, so the
    # cap only matters if more rules are added.
    _MAX_PENALTY = 0.9

    _SUPPORTED_PROBLEM_TYPES = ("classification", "regression", "nlp")

    def decide(self, dataset_info, problem_type):
        """Build a modelling strategy for a dataset.

        Args:
            dataset_info: Mapping queried with ``.get`` for the keys
                ``small_data``, ``high_dimensional``, ``imbalance`` and
                ``sparse_data``. A truthy value marks the risk as present;
                missing keys count as absent.
            problem_type: ``"classification"``, ``"regression"`` or
                ``"nlp"``.

        Returns:
            dict with keys ``model_family`` (str), ``reason`` (str),
            ``risks`` (list of risk labels, possibly empty) and
            ``confidence`` (``1 - min(total penalty, 0.9)`` rounded to two
            decimals).

        Raises:
            ValueError: If ``problem_type`` is not one of the supported
                values.
        """
        risks = []
        score = 0.0
        for flag, label, penalty in self._RISK_RULES:
            if dataset_info.get(flag):
                risks.append(label)
                score += penalty

        model_family, reason = self._select_model(problem_type, risks)

        return {
            "model_family": model_family,
            "reason": reason,
            "risks": risks,
            "confidence": round(1 - min(score, self._MAX_PENALTY), 2),
        }

    def _select_model(self, problem_type, risks):
        """Return ``(model_family, reason)`` for a problem type and its risks."""
        if problem_type == "classification":
            # Every classification branch currently selects the same family;
            # only the stated reason depends on the detected risks.
            if "small_dataset" in risks:
                return "tree_ensemble", "Small datasets benefit from simpler models"
            if "high_dimensionality" in risks:
                return (
                    "tree_ensemble",
                    "Tree ensembles handle high-dimensional data better",
                )
            return "tree_ensemble", "Tree ensembles handle complexity well"

        if problem_type == "regression":
            if "high_dimensionality" in risks:
                return (
                    "tree_ensemble",
                    "Tree ensembles handle high-dimensional data better",
                )
            return "linear_or_tree", "Balances interpretability and accuracy"

        if problem_type == "nlp":
            return "transformer", "Transformers best capture language semantics"

        # Fail with a clear message instead of letting an unassigned local
        # surface later as an UnboundLocalError.
        supported = ", ".join(repr(name) for name in self._SUPPORTED_PROBLEM_TYPES)
        raise ValueError(
            f"Unsupported problem_type: {problem_type!r}; expected one of {supported}"
        )

    def explain_strategy(self, strategy):
        """Describe the chosen model family, its reason and any risks.

        Args:
            strategy: Dict containing ``model_family`` and ``reason``;
                ``risks`` is optional and is omitted from the text when
                missing or empty.

        Returns:
            A one-line explanation string.
        """
        explanation = (
            f"Selected {strategy['model_family']} models because: "
            f"{strategy['reason']}."
        )
        if strategy.get("risks"):
            explanation += f" Identified risks: {', '.join(strategy['risks'])}."
        return explanation

    def explain_tradeoffs(self, strategy):
        """Describe the choice, detected risks and the confidence score.

        Args:
            strategy: Dict containing ``model_family`` and ``reason``;
                ``risks`` is optional. ``confidence`` is read with ``.get``,
                so a missing value is rendered as ``None``.

        Returns:
            A one-line explanation string ending with the confidence score.
        """
        explanation = (
            f"Chose {strategy['model_family']} due to: {strategy['reason']}."
        )
        if strategy.get("risks"):
            explanation += f" Risks detected: {', '.join(strategy['risks'])}."
        explanation += f" Confidence score: {strategy.get('confidence')}."
        return explanation