Spaces:
Runtime error
Runtime error
| import json | |
| from tqdm import tqdm | |
| import config | |
| from api_wrappers import hf_data_loader | |
| from generation_steps import synthetic_forward | |
def _load_manual_rewrites():
    """Return manually rewritten commit messages indexed by ``(hash, repo)``.

    Rows are shuffled with ``config.RANDOM_STATE`` before de-duplication, so
    when a commit appears more than once, the kept row is chosen pseudo-randomly
    but reproducibly for the same input. The resulting index is unique.

    Returns:
        DataFrame with a ``(hash, repo)`` MultiIndex and the columns
        ``commit_msg_start`` and ``commit_msg_end``.
    """
    manual_df = hf_data_loader.load_raw_rewriting_as_pandas()
    manual_df = manual_df.sample(frac=1, random_state=config.RANDOM_STATE).set_index(["hash", "repo"])[
        ["commit_msg_start", "commit_msg_end"]
    ]
    # Keep one row per commit so ``.loc[(hash, repo), col]`` yields a scalar.
    return manual_df[~manual_df.index.duplicated(keep="first")]


def transform(df):
    """Build the data-for-labeling artifact from commits with model predictions.

    Adds or overwrites these columns on a copy of *df*:

    * ``manual_sample``: True when the ``(hash, repo)`` pair has a manual
      rewrite in the raw rewriting dataset.
    * ``enhanced``: the manual ``commit_msg_end`` for manual samples;
      otherwise the message returned by ``synthetic_forward.generate_end_msg``
      for the row's ``prediction`` and ``mods``.
    * ``prediction``: replaced with the manual ``commit_msg_start`` for
      manual samples; left unchanged otherwise.
    * ``mods``: serialized with ``json.dumps`` so it can be stored in CSV.

    The result is written to ``config.DATA_FOR_LABELING_ARTIFACT``.

    Args:
        df: DataFrame with at least the ``hash``, ``repo``, ``prediction``
            and ``mods`` columns. It is not modified.

    Returns:
        The new DataFrame that was written to CSV.
    """
    print("Generating data for labeling:")
    synthetic_forward.print_config()
    tqdm.pandas()

    manual_df = _load_manual_rewrites()
    # Iterating a MultiIndex yields (hash, repo) tuples; a set gives O(1) lookups.
    manual_keys = set(manual_df.index)

    # Copy so the caller's DataFrame is not mutated in place.
    result = df.copy()
    result["manual_sample"] = [
        (commit_hash, repo) in manual_keys
        for commit_hash, repo in zip(result["hash"], result["repo"])
    ]

    def get_prediction_message(row):
        if row["manual_sample"]:
            return manual_df.loc[(row["hash"], row["repo"]), "commit_msg_start"]
        return row["prediction"]

    def get_enhanced_message(row):
        if row["manual_sample"]:
            return manual_df.loc[(row["hash"], row["repo"]), "commit_msg_end"]
        return synthetic_forward.generate_end_msg(start_msg=row["prediction"], diff=row["mods"])

    result["enhanced"] = result.progress_apply(get_enhanced_message, axis=1)
    result["prediction"] = result.progress_apply(get_prediction_message, axis=1)
    result["mods"] = result["mods"].progress_apply(json.dumps)

    result.to_csv(config.DATA_FOR_LABELING_ARTIFACT)
    print("Done")
    return result
def main():
    """Script entry point: build the labeling artifact from the full commit dataset.

    Sets ``synthetic_forward.GENERATION_ATTEMPTS`` to 3 for this run, then
    loads the commits with model predictions and hands them to ``transform``.
    """
    synthetic_forward.GENERATION_ATTEMPTS = 3
    commits_with_predictions = hf_data_loader.load_full_commit_with_predictions_as_pandas()
    transform(commits_with_predictions)
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()