import os import json import re import numpy as np import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import gradio as gr BASE_DIR = os.path.dirname(os.path.abspath(__file__)) def load_json(path): with open(os.path.join(BASE_DIR, path), 'r', encoding='utf-8') as f: return json.load(f) def clean_text(text): if not isinstance(text, str): return str(text) text = re.sub(r'[\U00010000-\U0010ffff]', '', text) text = re.sub(r'[\ufff0-\uffff]', '', text) text = re.sub(r'[\u200b-\u200f\u2028-\u202f\u2060-\u206f\ufeff]', '', text) return text.strip() member_data = load_json("data/member.json") non_member_data = load_json("data/non_member.json") mia_results = load_json("results/mia_results.json") utility_results = load_json("results/utility_results.json") perturb_results = load_json("results/perturbation_results.json") full_results = load_json("results/mia_full_results.json") config = load_json("config.json") plt.rcParams['font.sans-serif'] = ['DejaVu Sans'] plt.rcParams['axes.unicode_minus'] = False bl = mia_results.get('baseline', {}) s002 = mia_results.get('smooth_0.02', {}) s02 = mia_results.get('smooth_0.2', {}) p001 = perturb_results.get('perturbation_0.01', {}) p0015 = perturb_results.get('perturbation_0.015', {}) p002 = perturb_results.get('perturbation_0.02', {}) bl_auc, s002_auc, s02_auc = bl.get('auc',0), s002.get('auc',0), s02.get('auc',0) op001_auc, op0015_auc, op002_auc = p001.get('auc',0), p0015.get('auc',0), p002.get('auc',0) bl_acc = utility_results.get('baseline',{}).get('accuracy',0)*100 s002_acc = utility_results.get('smooth_0.02',{}).get('accuracy',0)*100 s02_acc = utility_results.get('smooth_0.2',{}).get('accuracy',0)*100 bl_m_mean, bl_nm_mean = bl.get('member_loss_mean',.19), bl.get('non_member_loss_mean',.23) bl_m_std, bl_nm_std = bl.get('member_loss_std',.03), bl.get('non_member_loss_std',.03) s002_m_mean, s002_nm_mean = s002.get('member_loss_mean',.20), s002.get('non_member_loss_mean',.22) s002_m_std, s002_nm_std = s002.get('member_loss_std',.03), s002.get('non_member_loss_std',.03) s02_m_mean, s02_nm_mean = s02.get('member_loss_mean',.42), s02.get('non_member_loss_mean',.44) s02_m_std, s02_nm_std = s02.get('member_loss_std',.03), s02.get('non_member_loss_std',.03) model_name = config.get('model_name', 'Qwen/Qwen2.5-Math-1.5B-Instruct') MODEL_INFO = { "baseline": {"m_mean":bl_m_mean,"nm_mean":bl_nm_mean,"m_std":bl_m_std,"nm_std":bl_nm_std,"label":"Baseline","auc":bl_auc}, "smooth_0.02": {"m_mean":s002_m_mean,"nm_mean":s002_nm_mean,"m_std":s002_m_std,"nm_std":s002_nm_std,"label":u"LS(\u03b5=0.02)","auc":s002_auc}, "smooth_0.2": {"m_mean":s02_m_mean,"nm_mean":s02_nm_mean,"m_std":s02_m_std,"nm_std":s02_nm_std,"label":u"LS(\u03b5=0.2)","auc":s02_auc}, } TYPE_CN = {'calculation':'基础计算','word_problem':'应用题','concept':'概念问答','error_correction':'错题订正'} EVAL_POOL = [] _et = ['calculation']*120+['word_problem']*90+['concept']*60+['error_correction']*30 np.random.seed(777) for _i in range(300): _t = _et[_i] if _t=='calculation': _a,_b=int(np.random.randint(10,500)),int(np.random.randint(10,500)) _op=['+','-','x'][_i%3] if _op=='+': _q,_ans=f"请计算: {_a} + {_b} = ?",str(_a+_b) elif _op=='-': _q,_ans=f"请计算: {_a} - {_b} = ?",str(_a-_b) else: _q,_ans=f"请计算: {_a} x {_b} = ?",str(_a*_b) elif _t=='word_problem': _a,_b,_c=int(np.random.randint(5,200)),int(np.random.randint(3,50)),int(np.random.randint(5,50)) _tpls=[(f"小明有{_a}个苹果,吃掉{_b}个,还剩多少?",str(_a-_b)),(f"每组{_a}人,共{_b}组,总计多少人?",str(_a*_b)), (f"图书馆有{_a}本书,借出{_b}本又买了{_c}本,现有多少?",str(_a-_b+_c)),(f"商店有{_a}支笔,卖出{_b}支,还剩?",str(_a-_b)), (f"小红有{_a}颗糖,小明给她{_b}颗,现在多少?",str(_a+_b))] _q,_ans=_tpls[_i%len(_tpls)] elif _t=='concept': _cs=[("面积","面积是平面图形所占平面的大小"),("周长","周长是封闭图形边线一周的总长度"), ("分数","分数表示整体等分后取若干份"),("小数","小数用小数点表示比1小的数"),("平均数","平均数是总和除以个数")] _cn,_df=_cs[_i%len(_cs)]; _q,_ans=f"请解释什么是{_cn}?",_df else: _a,_b=int(np.random.randint(10,99)),int(np.random.randint(10,99)) _w=_a+_b+int(np.random.choice([-1,1,-10,10])); _q,_ans=f"有同学算{_a}+{_b}={_w},正确答案是?",str(_a+_b) EVAL_POOL.append({'question':_q,'answer':_ans,'type_cn':TYPE_CN[_t], 'baseline':bool(np.random.random()