SamanthaStorm commited on
Commit
2a840d3
·
verified ·
1 Parent(s): 879a93f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -45
app.py CHANGED
@@ -66,12 +66,16 @@ class FallacyFinder:
66
  try:
67
  logger.info("Loading model: SamanthaStorm/fallacyfinder")
68
  self.tokenizer = AutoTokenizer.from_pretrained("SamanthaStorm/fallacyfinder")
69
- self.model = AutoModelForSequenceClassification.from_pretrained("SamanthaStorm/fallacyfinder")
 
 
 
70
  self.use_model = True
71
  logger.info("✅ Model loaded successfully!")
72
  except Exception as e:
73
  logger.error(f"❌ Error loading model: {e}")
74
- logger.info("Falling back to rule-based approach")
 
75
 
76
  def predict_with_rules(self, text):
77
  """Rule-based fallacy detection for when model isn't available"""
@@ -206,54 +210,52 @@ class FallacyFinder:
206
  return 'no_fallacy', 0.60, ["no_specific_patterns"]
207
 
208
  def predict_fallacy(self, text):
209
- """Main prediction function"""
210
  if not text.strip():
211
  return None, 0, "Please enter a message to analyze.", []
212
 
213
  logger.info(f"ANALYZING: '{text[:100]}{'...' if len(text) > 100 else ''}'")
214
 
215
- if self.use_model and self.model is not None:
216
- # Use trained model
217
- try:
218
- logger.info("Using trained model for prediction")
219
- inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
220
-
221
- with torch.no_grad():
222
- outputs = self.model(**inputs)
223
- predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
224
- predicted_class_id = predictions.argmax().item()
225
- confidence = predictions.max().item()
226
-
227
- # Log all prediction scores
228
- label_keys = list(self.fallacy_labels.keys())
229
- prediction_scores = {}
230
- for i, label in enumerate(label_keys):
231
- score = predictions[0][i].item()
232
- prediction_scores[label] = f"{score:.3f}"
233
-
234
- logger.info(f"MODEL PREDICTIONS: {prediction_scores}")
235
-
236
- # Map to fallacy label
237
- predicted_label = label_keys[predicted_class_id]
238
- logger.info(f"MODEL RESULT: {predicted_label} (confidence: {confidence:.3f})")
239
-
240
- patterns_detected = [f"model_prediction: top_3: {sorted(prediction_scores.items(), key=lambda x: float(x[1]), reverse=True)[:3]}"]
241
-
242
- except Exception as e:
243
- logger.error(f"Model prediction failed: {e}")
244
- logger.info("Falling back to rule-based approach")
245
- predicted_label, confidence, patterns_detected = self.predict_with_rules(text)
246
- else:
247
- # Use rule-based approach
248
- logger.info("Using rule-based approach")
249
- predicted_label, confidence, patterns_detected = self.predict_with_rules(text)
250
-
251
- fallacy_name = self.fallacy_labels[predicted_label]
252
- description = self.fallacy_descriptions[predicted_label]
253
 
254
- logger.info(f"FINAL RESULT: {predicted_label} ({fallacy_name}) - confidence: {confidence}")
255
-
256
- return predicted_label, confidence, fallacy_name, description, patterns_detected
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
  def analyze_message(self, message):
259
  """Analyze a message and return formatted results"""
@@ -315,7 +317,8 @@ finder = FallacyFinder()
315
  logger.info("Fallacy Finder initialized successfully")
316
 
317
  def analyze_fallacy(message):
318
- return finder.analyze_message(message)
 
319
 
320
  # Create the Gradio interface
321
  with gr.Blocks(
 
66
  try:
67
  logger.info("Loading model: SamanthaStorm/fallacyfinder")
68
  self.tokenizer = AutoTokenizer.from_pretrained("SamanthaStorm/fallacyfinder")
69
+ self.model = AutoModelForSequenceClassification.from_pretrained(
70
+ "SamanthaStorm/fallacyfinder",
71
+ num_labels=16 # Specify 16 fallacy classes
72
+ )
73
  self.use_model = True
74
  logger.info("✅ Model loaded successfully!")
75
  except Exception as e:
76
  logger.error(f"❌ Error loading model: {e}")
77
+ logger.error("Model loading failed - cannot continue without trained model")
78
+ raise e # Stop execution if model fails to load
79
 
80
  def predict_with_rules(self, text):
81
  """Rule-based fallacy detection for when model isn't available"""
 
210
  return 'no_fallacy', 0.60, ["no_specific_patterns"]
211
 
212
  def predict_fallacy(self, text):
213
+ """Main prediction function using trained model only"""
214
  if not text.strip():
215
  return None, 0, "Please enter a message to analyze.", []
216
 
217
  logger.info(f"ANALYZING: '{text[:100]}{'...' if len(text) > 100 else ''}'")
218
 
219
+ if not self.use_model or self.model is None:
220
+ logger.error("Trained model not available - cannot analyze")
221
+ return None, 0, "Model not loaded", []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
+ # Use trained model
224
+ try:
225
+ logger.info("Using trained model for prediction")
226
+ inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
227
+
228
+ with torch.no_grad():
229
+ outputs = self.model(**inputs)
230
+ predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
231
+ predicted_class_id = predictions.argmax().item()
232
+ confidence = predictions.max().item()
233
+
234
+ # Log all prediction scores
235
+ label_keys = list(self.fallacy_labels.keys())
236
+ prediction_scores = {}
237
+ for i, label in enumerate(label_keys):
238
+ score = predictions[0][i].item()
239
+ prediction_scores[label] = f"{score:.3f}"
240
+
241
+ logger.info(f"MODEL PREDICTIONS: {prediction_scores}")
242
+
243
+ # Map to fallacy label
244
+ predicted_label = label_keys[predicted_class_id]
245
+ logger.info(f"MODEL RESULT: {predicted_label} (confidence: {confidence:.3f})")
246
+
247
+ patterns_detected = [f"model_prediction: top_3: {sorted(prediction_scores.items(), key=lambda x: float(x[1]), reverse=True)[:3]}"]
248
+
249
+ fallacy_name = self.fallacy_labels[predicted_label]
250
+ description = self.fallacy_descriptions[predicted_label]
251
+
252
+ logger.info(f"FINAL RESULT: {predicted_label} ({fallacy_name}) - confidence: {confidence}")
253
+
254
+ return predicted_label, confidence, fallacy_name, description, patterns_detected
255
+
256
+ except Exception as e:
257
+ logger.error(f"Model prediction failed: {e}")
258
+ return None, 0, "Prediction failed", []
259
 
260
  def analyze_message(self, message):
261
  """Analyze a message and return formatted results"""
 
317
  logger.info("Fallacy Finder initialized successfully")
318
 
319
  def analyze_fallacy(message):
320
+ result, explanation, _ = finder.analyze_message(message) # Ignore 3rd return value
321
+ return result, explanation
322
 
323
  # Create the Gradio interface
324
  with gr.Blocks(