ngocminhta commited on
Commit
617c3f7
·
1 Parent(s): 667fbf3

update model specific pred

Browse files
Files changed (2) hide show
  1. app.py +12 -3
  2. infer.py +62 -0
app.py CHANGED
@@ -8,7 +8,7 @@ from src.text_embedding import TextEmbeddingModel
8
  from src.index import Indexer
9
  import os
10
  import pickle
11
- from infer import infer_3_class
12
  import uvicorn
13
  from datasets import disable_caching
14
  disable_caching()
@@ -40,7 +40,7 @@ def load_pkl(path):
40
 
41
  @app.on_event("startup")
42
  def load_model_resources():
43
- global model, tokenizer, index, label_dict, is_mixed_dict
44
 
45
  model = TextEmbeddingModel(opt.model_name)
46
  tokenizer=model.tokenizer
@@ -49,6 +49,7 @@ def load_model_resources():
49
  index.deserialize_from(opt.database_path)
50
  label_dict=load_pkl(os.path.join(opt.database_path,'label_dict.pkl'))
51
  is_mixed_dict=load_pkl(os.path.join(opt.database_path,'is_mixed_dict.pkl'))
 
52
 
53
 
54
  @app.route('/predict', methods=['POST'])
@@ -67,7 +68,15 @@ async def predict(request: Request):
67
  K=21)
68
  return JSONResponse(content={"results": results})
69
  elif mode == "advanced":
70
- return 0
 
 
 
 
 
 
 
 
71
 
72
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
73
 
 
8
  from src.index import Indexer
9
  import os
10
  import pickle
11
+ from infer import infer_3_class, infer_model_specific
12
  import uvicorn
13
  from datasets import disable_caching
14
  disable_caching()
 
40
 
41
  @app.on_event("startup")
42
  def load_model_resources():
43
+ global model, tokenizer, index, label_dict, is_mixed_dict, write_model_dict
44
 
45
  model = TextEmbeddingModel(opt.model_name)
46
  tokenizer=model.tokenizer
 
49
  index.deserialize_from(opt.database_path)
50
  label_dict=load_pkl(os.path.join(opt.database_path,'label_dict.pkl'))
51
  is_mixed_dict=load_pkl(os.path.join(opt.database_path,'is_mixed_dict.pkl'))
52
+ write_model_dict=load_pkl(os.path.join(opt.database_path,'write_model_dict.pkl'))
53
 
54
 
55
  @app.route('/predict', methods=['POST'])
 
68
  K=21)
69
  return JSONResponse(content={"results": results})
70
  elif mode == "advanced":
71
+ results = infer_model_specific(model=model,
72
+ tokenizer=tokenizer,
73
+ index=index,
74
+ label_dict=label_dict,
75
+ is_mixed_dict=is_mixed_dict,
76
+ write_model_dict=write_model_dict,
77
+ text_list=text_list,
78
+ K=9)
79
+ return JSONResponse(content={"results": results})
80
 
81
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
82
 
infer.py CHANGED
@@ -79,4 +79,66 @@ def infer_3_class(model, tokenizer, index, label_dict, is_mixed_dict, text_list,
79
  final[1] = round(fuzzy_cnt[(0,10^3)] / total_score*100,2)
80
  final[2] = round(fuzzy_cnt[(1,1)] / total_score*100,2)
81
  pred.append(final)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  return pred
 
79
  final[1] = round(fuzzy_cnt[(0,10^3)] / total_score*100,2)
80
  final[2] = round(fuzzy_cnt[(1,1)] / total_score*100,2)
81
  pred.append(final)
82
+ return pred
83
+
84
+ def infer_model_specific(model, tokenizer, index, label_dict, is_mixed_dict, write_model_dict, text_list, K):
85
+ encoded_text = tokenizer.batch_encode_plus(
86
+ text_list,
87
+ return_tensors="pt",
88
+ max_length=512,
89
+ padding="max_length",
90
+ truncation=True,
91
+ )
92
+ encoded_text = {k: v for k, v in encoded_text.items()}
93
+ embeddings = model(encoded_text).cpu().detach().numpy()
94
+ top_ids_and_scores = index.search_knn(embeddings, K)
95
+ pred = []
96
+
97
+ for i, (ids, scores) in enumerate(top_ids_and_scores):
98
+ sorted_scores = np.argsort(scores)
99
+ sorted_scores = sorted_scores[::-1]
100
+
101
+ topk_ids = [ids[j] for j in sorted_scores]
102
+ topk_scores = [scores[j] for j in sorted_scores]
103
+ weights = softmax_weights(topk_scores, temperature=0.4)
104
+ candidate_models = [is_mixed_dict[int(_id)] for _id in topk_ids]
105
+ initial_pred = Counter(candidate_models).most_common(1)[0][0]
106
+
107
+ # Initialize fuzzy counts for both 3-class and model-specific predictions
108
+ fuzzy_cnt_3class = {(1,0): 0.0, (0,10^3): 0.0, (1,1): 0.0}
109
+ fuzzy_cnt_model = {
110
+ (1, 0, 0): 0.0, # Human
111
+ (0, 10^3, 1): 0.0, (0, 10^3, 2): 0.0, (0, 10^3, 3): 0.0, (0, 10^3, 4): 0.0, # AI
112
+ (1, 1, 1): 0.0, (1, 1, 2): 0.0, (1, 1, 3): 0.0, (1, 1, 4): 0.0 # Human+AI
113
+ }
114
+
115
+ for id, weight in zip(topk_ids, weights):
116
+ # Update 3-class fuzzy counts
117
+ label_3class = (label_dict[int(id)], is_mixed_dict[int(id)])
118
+ boost_3class = class_type_boost(is_mixed_dict[int(id)], initial_pred)
119
+ fuzzy_cnt_3class[label_3class] += weight * boost_3class
120
+
121
+ # Update model-specific fuzzy counts
122
+ label_model = (label_dict[int(id)], is_mixed_dict[int(id)], write_model_dict[int(id)])
123
+ boost_model = class_type_boost(is_mixed_dict[int(id)], initial_pred)
124
+ fuzzy_cnt_model[label_model] += weight * boost_model
125
+
126
+ # Calculate 3-class probabilities
127
+ total_score_3class = sum(fuzzy_cnt_3class.values())
128
+ final_3class = {
129
+ 0: round(fuzzy_cnt_3class[(1,0)] / total_score_3class * 100, 2),
130
+ 1: round(fuzzy_cnt_3class[(0,10^3)] / total_score_3class * 100, 2),
131
+ 2: round(fuzzy_cnt_3class[(1,1)] / total_score_3class * 100, 2)
132
+ }
133
+
134
+ # Get model-specific prediction
135
+ final_model = max(fuzzy_cnt_model, key=fuzzy_cnt_model.get)
136
+
137
+ # Combine both predictions
138
+ final = {
139
+ "score": final_3class,
140
+ "model": final_model
141
+ }
142
+ pred.append(final)
143
+
144
  return pred