Spaces:
Sleeping
Sleeping
class StrategyReasoner:
    """Rule-based selector of a model family for a dataset/problem pair.

    ``decide`` turns dataset flags plus a problem type into a ``strategy``
    dict (model family, rationale, detected risks, heuristic confidence);
    the ``explain_*`` methods render such a dict as human-readable text.
    """

    # (dataset_info flag, risk label, confidence penalty). Order matters: it
    # fixes both the order of the returned ``risks`` list and the order in
    # which penalties are summed.
    # NOTE(review): "sparse_data" is reported as "high_missingness"; confirm
    # the upstream profiler sets this flag for missing values, not sparsity.
    _RISK_RULES = (
        ("small_data", "small_dataset", 0.1),
        ("high_dimensional", "high_dimensionality", 0.1),
        ("imbalance", "class_imbalance", 0.2),
        ("sparse_data", "high_missingness", 0.2),
    )

    # Upper bound on the accumulated penalty, so confidence is never below 0.1.
    _MAX_PENALTY = 0.9

    _SUPPORTED_PROBLEM_TYPES = ("classification", "regression", "nlp")

    def decide(self, dataset_info, problem_type):
        """Choose a model family for ``problem_type`` given dataset traits.

        Args:
            dataset_info: Mapping read via ``.get`` for the flags
                "small_data", "high_dimensional", "imbalance" and
                "sparse_data"; a flag counts when its value is truthy.
                ``None`` is treated as an empty mapping.
            problem_type: "classification", "regression" or "nlp".

        Returns:
            dict with keys "model_family", "reason", "risks" (risk labels in
            detection order) and "confidence" (1 minus the capped sum of
            risk penalties, rounded to two decimals).

        Raises:
            ValueError: if ``problem_type`` is not one of the supported types.
        """
        if dataset_info is None:
            dataset_info = {}

        risks = []
        score = 0.0
        for flag, risk, penalty in self._RISK_RULES:
            if dataset_info.get(flag):
                risks.append(risk)
                score += penalty

        model_family, reason = self._select_model(problem_type, risks)

        return {
            "model_family": model_family,
            "reason": reason,
            "risks": risks,
            "confidence": round(1 - min(score, self._MAX_PENALTY), 2),
        }

    def _select_model(self, problem_type, risks):
        """Return the ``(model_family, reason)`` pair for a problem type."""
        if problem_type == "classification":
            # NOTE(review): every classification branch picks tree_ensemble;
            # only the stated reason varies. The small-data reason speaks of
            # "simpler models" - confirm tree_ensemble is really intended.
            if "small_dataset" in risks:
                return "tree_ensemble", "Small datasets benefit from simpler models"
            if "high_dimensionality" in risks:
                return (
                    "tree_ensemble",
                    "Tree ensembles handle high-dimensional data better",
                )
            return "tree_ensemble", "Tree ensembles handle complexity well"
        if problem_type == "regression":
            if "high_dimensionality" in risks:
                return (
                    "tree_ensemble",
                    "Tree ensembles handle high-dimensional data better",
                )
            return "linear_or_tree", "Balances interpretability and accuracy"
        if problem_type == "nlp":
            return "transformer", "Transformers best capture language semantics"
        # Without this, an unknown type would leave the result unbound and
        # surface as an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported problem_type {problem_type!r}; expected one of: "
            f"{', '.join(self._SUPPORTED_PROBLEM_TYPES)}"
        )

    def explain_strategy(self, strategy):
        """Return a one-line justification of the chosen model family.

        Reads "model_family" and "reason" (required) and "risks" (optional)
        from a dict shaped like the output of ``decide``.
        """
        explanation = (
            f"Selected {strategy['model_family']} models because: "
            f"{strategy['reason']}."
        )
        risks = strategy.get("risks")
        if risks:
            explanation += f" Identified risks: {', '.join(risks)}."
        return explanation

    def explain_tradeoffs(self, strategy):
        """Return a justification that also reports the confidence score.

        Like ``explain_strategy`` but worded differently and always appends
        the "confidence" value (rendered as ``None`` when absent).
        """
        explanation = (
            f"Chose {strategy['model_family']} due to: {strategy['reason']}."
        )
        risks = strategy.get("risks")
        if risks:
            explanation += f" Risks detected: {', '.join(risks)}."
        explanation += f" Confidence score: {strategy.get('confidence')}."
        return explanation