Syamchand commited on
Commit
df560df
·
verified ·
1 Parent(s): e91ca6f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -10
app.py CHANGED
@@ -247,21 +247,28 @@ def container_memory():
247
 
248
  @app.post("/predict/contracts_clauses", response_model=ClassificationResult)
249
  def predict_contracts_clauses(req: TextRequest):
 
250
  # The SetFit model predicts labels directly (no integer conversion needed)
251
- preds = models["contracts_clauses"].predict([req.text]) # Ensure using predict()
252
  label = preds[0] # Already a string like 'terms'
253
 
254
  # Try to get a confidence score using predict_proba if available
255
  score = 1.0
256
- if hasattr(models["contracts_clauses"], "predict_proba"):
257
- probs = models["contracts_clauses"].predict_proba([req.text])[0]
258
- # Get the index of the predicted label and its probability
259
- # Note: SetFit often stores labels in .model.labels
260
- if hasattr(models["contracts_clauses"].model, "labels"):
261
- idx = list(models["contracts_clauses"].model.labels).index(label)
262
- score = probs[idx]
263
- else:
264
- score = max(probs)
 
 
 
 
 
 
265
 
266
  return ClassificationResult(label=label, score=round(float(score), 4))
267
 
 
247
 
248
  @app.post("/predict/contracts_clauses", response_model=ClassificationResult)
249
  def predict_contracts_clauses(req: TextRequest):
250
+ model = models["contracts_clauses"]
251
  # The SetFit model predicts labels directly (no integer conversion needed)
252
+ preds = model.predict([req.text])
253
  label = preds[0] # Already a string like 'terms'
254
 
255
  # Try to get a confidence score using predict_proba if available
256
  score = 1.0
257
+ if hasattr(model, "predict_proba"):
258
+ try:
259
+ probs = model.predict_proba([req.text])[0]
260
+ # model.labels stores the label strings in the order expected by predict_proba
261
+ if hasattr(model, "labels") and model.labels is not None:
262
+ # Find the index of the predicted label
263
+ if label in model.labels:
264
+ idx = model.labels.index(label)
265
+ score = probs[idx]
266
+ else:
267
+ score = max(probs)
268
+ else:
269
+ score = max(probs)
270
+ except Exception:
271
+ score = 1.0
272
 
273
  return ClassificationResult(label=label, score=round(float(score), 4))
274