stevafernandes committed on
Commit
e96c629
·
verified ·
1 Parent(s): 25a20f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +171 -178
app.py CHANGED
@@ -123,24 +123,16 @@ CLASSIFIER_PATH = os.path.join(MODEL_DIR, "classifier.pkl")
123
  SCORER_PATH = os.path.join(MODEL_DIR, "scorer.pkl")
124
  SCALER_PATH = os.path.join(MODEL_DIR, "scaler.pkl")
125
  THRESHOLD_PATH = os.path.join(MODEL_DIR, "thresholds.pkl")
126
- ENSEMBLE_PATH = os.path.join(MODEL_DIR, "ensemble.pkl")
127
 
128
  @st.cache_resource
129
  def load_sentence_transformer():
130
  """Load sentence transformer model"""
131
- models_to_try = [
132
- 'all-MiniLM-L6-v2', # Lightweight and reliable
133
- 'all-mpnet-base-v2' # Good alternative
134
- ]
135
-
136
- for model_name in models_to_try:
137
- try:
138
- model = SentenceTransformer(model_name)
139
- return model, model_name
140
- except:
141
- continue
142
-
143
- return SentenceTransformer('all-MiniLM-L6-v2'), 'all-MiniLM-L6-v2'
144
 
145
  def segment_text(text, embedder):
146
  """Segment text into meaningful chunks"""
@@ -154,7 +146,6 @@ def segment_text(text, embedder):
154
  if len(sentences) < 3:
155
  return [text]
156
 
157
- # Group sentences into segments
158
  segments = []
159
  current_segment = []
160
  for sent in sentences:
@@ -212,16 +203,22 @@ def extract_features(text, embedder, category_focus=None):
212
 
213
  # Get embeddings
214
  try:
215
- embedding = embedder.encode(text, convert_to_tensor=False, normalize_embeddings=True)
 
 
 
216
  except:
217
- embedding = embedder.encode(text)
218
 
219
  # Category similarity
220
  if category_focus and category_focus in CATEGORIES:
221
  category_text = f"{CATEGORIES[category_focus]['description']} {' '.join(CATEGORIES[category_focus]['keywords'][:10])}"
222
  try:
223
- category_embedding = embedder.encode(category_text, normalize_embeddings=True)
224
- similarity = cosine_similarity([embedding], [category_embedding])[0][0]
 
 
 
225
  features.append(similarity * 10)
226
  except:
227
  features.append(0)
@@ -229,7 +226,7 @@ def extract_features(text, embedder, category_focus=None):
229
  features.append(0)
230
 
231
  features = np.array(features, dtype=np.float32)
232
- combined_features = np.concatenate([features, embedding[:256]]) # Limit embedding size
233
 
234
  return combined_features
235
 
@@ -496,179 +493,175 @@ def analyze_statement(text, embedder, scaler, classifiers, scorers, thresholds):
496
 
497
  return segment_results, category_results
498
 
499
- # Main application
500
- def main():
501
- st.title("πŸ₯ Medical School Personal Statement Analyzer")
502
- st.markdown("*AI-powered analysis based on medical school admission rubrics*")
503
- st.markdown("---")
504
-
505
- # Sidebar
506
- with st.sidebar:
507
- st.header("ℹ️ About")
508
- st.markdown("""
509
- This tool analyzes personal statements based on 4 key categories:
510
- - **Spark**: Opening that shows interest in medicine
511
- - **Healthcare Experience**: Clinical/medical experiences
512
- - **Doctor Qualities**: Leadership and character traits
513
- - **Spin**: Connecting experiences to medical career
514
-
515
- Each category is scored 1-4 (Poor to Excellent)
516
- """)
517
-
518
- # Create tabs
519
- tab1, tab2, tab3 = st.tabs(["πŸ“š Train Model", "πŸ“ Analyze Statement", "πŸ“Š View Rubrics"])
520
-
521
- # Train Model Tab
522
- with tab1:
523
- st.header("Train the AI Model")
524
-
525
- if all(os.path.exists(p) for p in [CLASSIFIER_PATH, SCORER_PATH, SCALER_PATH]):
526
- st.success("βœ“ Models already trained. You can analyze statements or retrain.")
527
-
528
- st.markdown("Upload training data files (Excel format with coded excerpts)")
529
-
530
- col1, col2 = st.columns(2)
531
- with col1:
532
- file1 = st.file_uploader("Training File 1", type=['xlsx'], key="file1")
533
- with col2:
534
- file2 = st.file_uploader("Training File 2", type=['xlsx'], key="file2")
535
-
536
- if file1 and file2:
537
- if st.button("Start Training", type="primary"):
538
- try:
539
- # Load data
540
- with st.spinner("Loading training data..."):
541
- df = load_training_data(file1, file2)
542
-
543
- if df.empty:
544
- st.error("No valid training data found.")
545
- return
546
-
547
  st.success(f"βœ“ Loaded {len(df)} training samples")
548
 
549
  # Load embedder
550
  with st.spinner("Loading transformer model..."):
551
  embedder, embedder_name = load_sentence_transformer()
552
 
553
- # Train
554
- scaler, classifiers, scorers, thresholds = train_models(df, embedder)
 
 
 
 
 
 
 
555
 
556
- # Save
557
- save_models(embedder_name, scaler, classifiers, scorers, thresholds)
558
- st.success("βœ“ Training complete! Models saved.")
559
-
560
- except Exception as e:
561
- st.error(f"Training failed: {str(e)}")
562
 
563
- # Analyze Statement Tab
564
- with tab2:
565
- st.header("Analyze Personal Statement")
566
-
567
- if not all(os.path.exists(p) for p in [CLASSIFIER_PATH, SCORER_PATH, SCALER_PATH]):
568
- st.warning("⚠️ Please train the model first (Tab 1)")
569
- return
570
-
571
  # Load models
572
  embedder, scaler, classifiers, scorers, thresholds = load_saved_models()
573
 
574
  if embedder is None:
575
  st.error("Failed to load models. Please retrain.")
576
- return
577
-
578
- # Input method
579
- input_method = st.radio("Choose input method:", ["Paste Text", "Upload File"])
580
-
581
- text_to_analyze = None
582
-
583
- if input_method == "Paste Text":
584
- text_to_analyze = st.text_area(
585
- "Paste your personal statement here:",
586
- height=300,
587
- placeholder="Enter your personal statement..."
588
- )
589
  else:
590
- uploaded_file = st.file_uploader("Upload statement (.txt)", type=['txt'])
591
- if uploaded_file:
592
- text_to_analyze = str(uploaded_file.read(), 'utf-8')
593
- st.success("File uploaded successfully!")
594
-
595
- if text_to_analyze and st.button("Analyze Statement", type="primary"):
596
- with st.spinner("Analyzing..."):
597
- segment_results, category_results = analyze_statement(
598
- text_to_analyze, embedder, scaler, classifiers, scorers, thresholds
599
- )
600
-
601
- # Display results
602
- st.success("βœ“ Analysis complete!")
603
-
604
- # Summary
605
- st.subheader("πŸ“Š Overall Summary")
606
- cols = st.columns(4)
607
 
608
- detected = [cat for cat, res in category_results.items() if res['detected']]
609
 
610
- with cols[0]:
611
- st.metric("Categories Found", f"{len(detected)}/4")
612
- with cols[1]:
613
- if detected:
614
- avg_score = np.mean([category_results[cat]['score'] for cat in detected])
615
- st.metric("Average Score", f"{avg_score:.1f}/4")
616
- else:
617
- st.metric("Average Score", "N/A")
618
- with cols[2]:
619
- st.metric("Total Segments", len(segment_results))
620
- with cols[3]:
621
- quality = "Excellent" if len(detected) == 4 and avg_score >= 3.5 else "Good" if len(detected) >= 3 else "Needs Work"
622
- st.metric("Overall", quality)
623
-
624
- # Category breakdown
625
- st.subheader("πŸ“‹ Category Analysis")
626
- for cat in CATEGORIES.keys():
627
- res = category_results[cat]
628
- if res['detected']:
629
- icon = "βœ…" if res['score'] >= 3 else "⚠️" if res['score'] >= 2 else "❌"
630
- st.write(f"{icon} **{cat}**: Score {res['score']}/4 (Confidence: {res['confidence']:.1%})")
631
- else:
632
- st.write(f"❌ **{cat}**: Not detected")
633
-
634
- # Segment details
635
- st.subheader("πŸ“ Segment Details")
636
- for seg in segment_results:
637
- with st.expander(f"Segment {seg['segment_num']}: {seg['category']}"):
638
- st.write(f"**Score:** {seg['score']}/4" if seg['score'] else "N/A")
639
- st.write(f"**Confidence:** {seg['confidence']:.1%}")
640
- st.write(f"**Text:** {seg['text'][:300]}...")
641
-
642
- # Recommendations
643
- st.subheader("πŸ’‘ Recommendations")
644
- missing = [cat for cat, res in category_results.items() if not res['detected']]
645
- low_score = [cat for cat, res in category_results.items()
646
- if res['detected'] and res['score'] and res['score'] < 3]
647
-
648
- if missing:
649
- st.warning("**Missing Categories:**")
650
- for cat in missing:
651
- st.write(f"β€’ Add content for **{cat}**: {CATEGORIES[cat]['description']}")
652
-
653
- if low_score:
654
- st.info("**Areas to Improve:**")
655
- for cat in low_score:
656
- st.write(f"β€’ Strengthen **{cat}** (current score: {category_results[cat]['score']}/4)")
657
 
658
- if not missing and not low_score:
659
- st.success("Excellent work! All categories present with good scores.")
660
-
661
- # View Rubrics Tab
662
- with tab3:
663
- st.header("Scoring Rubrics")
664
-
665
- for category, info in CATEGORIES.items():
666
- with st.expander(f"**{category}**"):
667
- st.write(f"**Description:** {info['description']}")
668
- st.write("**Scoring Criteria:**")
669
- for score in [4, 3, 2, 1]:
670
- st.write(f"β€’ **Score {score}:** {info['rubric'][score]}")
671
- st.write(f"**Key Terms:** {', '.join(info['keywords'][:8])}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
672
 
673
- if __name__ == "__main__":
674
- main()
 
 
 
 
 
 
 
 
 
 
123
  SCORER_PATH = os.path.join(MODEL_DIR, "scorer.pkl")
124
  SCALER_PATH = os.path.join(MODEL_DIR, "scaler.pkl")
125
  THRESHOLD_PATH = os.path.join(MODEL_DIR, "thresholds.pkl")
 
126
 
127
  @st.cache_resource
128
  def load_sentence_transformer():
129
  """Load sentence transformer model"""
130
+ try:
131
+ model = SentenceTransformer('all-MiniLM-L6-v2')
132
+ return model, 'all-MiniLM-L6-v2'
133
+ except:
134
+ st.error("Failed to load sentence transformer model")
135
+ return None, None
 
 
 
 
 
 
 
136
 
137
  def segment_text(text, embedder):
138
  """Segment text into meaningful chunks"""
 
146
  if len(sentences) < 3:
147
  return [text]
148
 
 
149
  segments = []
150
  current_segment = []
151
  for sent in sentences:
 
203
 
204
  # Get embeddings
205
  try:
206
+ embedding = embedder.encode(text, convert_to_tensor=False)
207
+ if hasattr(embedding, 'cpu'):
208
+ embedding = embedding.cpu().numpy()
209
+ embedding = embedding.flatten()[:256] # Limit size
210
  except:
211
+ embedding = np.zeros(256)
212
 
213
  # Category similarity
214
  if category_focus and category_focus in CATEGORIES:
215
  category_text = f"{CATEGORIES[category_focus]['description']} {' '.join(CATEGORIES[category_focus]['keywords'][:10])}"
216
  try:
217
+ category_embedding = embedder.encode(category_text)
218
+ if hasattr(category_embedding, 'cpu'):
219
+ category_embedding = category_embedding.cpu().numpy()
220
+ category_embedding = category_embedding.flatten()
221
+ similarity = cosine_similarity([embedding], [category_embedding[:256]])[0][0]
222
  features.append(similarity * 10)
223
  except:
224
  features.append(0)
 
226
  features.append(0)
227
 
228
  features = np.array(features, dtype=np.float32)
229
+ combined_features = np.concatenate([features, embedding])
230
 
231
  return combined_features
232
 
 
493
 
494
  return segment_results, category_results
495
 
496
+ # Main UI Code
497
+ st.title("πŸ₯ Medical School Personal Statement Analyzer")
498
+ st.markdown("*AI-powered analysis based on medical school admission rubrics*")
499
+ st.markdown("---")
500
+
501
+ # Sidebar
502
+ with st.sidebar:
503
+ st.header("ℹ️ About")
504
+ st.markdown("""
505
+ This tool analyzes personal statements based on 4 key categories:
506
+ - **Spark**: Opening that shows interest in medicine
507
+ - **Healthcare Experience**: Clinical/medical experiences
508
+ - **Doctor Qualities**: Leadership and character traits
509
+ - **Spin**: Connecting experiences to medical career
510
+
511
+ Each category is scored 1-4 (Poor to Excellent)
512
+ """)
513
+
514
+ # Create tabs
515
+ tab1, tab2, tab3 = st.tabs(["πŸ“š Train Model", "πŸ“ Analyze Statement", "πŸ“Š View Rubrics"])
516
+
517
+ # Train Model Tab
518
+ with tab1:
519
+ st.header("Train the AI Model")
520
+
521
+ if all(os.path.exists(p) for p in [CLASSIFIER_PATH, SCORER_PATH, SCALER_PATH]):
522
+ st.success("βœ“ Models already trained. You can analyze statements or retrain.")
523
+
524
+ st.markdown("Upload training data files (Excel format with coded excerpts)")
525
+
526
+ col1, col2 = st.columns(2)
527
+ with col1:
528
+ file1 = st.file_uploader("Training File 1", type=['xlsx'], key="file1")
529
+ with col2:
530
+ file2 = st.file_uploader("Training File 2", type=['xlsx'], key="file2")
531
+
532
+ if file1 and file2:
533
+ if st.button("Start Training", type="primary"):
534
+ try:
535
+ # Load data
536
+ with st.spinner("Loading training data..."):
537
+ df = load_training_data(file1, file2)
538
+
539
+ if df.empty:
540
+ st.error("No valid training data found.")
541
+ else:
 
 
542
  st.success(f"βœ“ Loaded {len(df)} training samples")
543
 
544
  # Load embedder
545
  with st.spinner("Loading transformer model..."):
546
  embedder, embedder_name = load_sentence_transformer()
547
 
548
+ if embedder is not None:
549
+ # Train
550
+ scaler, classifiers, scorers, thresholds = train_models(df, embedder)
551
+
552
+ # Save
553
+ save_models(embedder_name, scaler, classifiers, scorers, thresholds)
554
+ st.success("βœ“ Training complete! Models saved.")
555
+ else:
556
+ st.error("Failed to load transformer model")
557
 
558
+ except Exception as e:
559
+ st.error(f"Training failed: {str(e)}")
560
+
561
+ # Analyze Statement Tab
562
+ with tab2:
563
+ st.header("Analyze Personal Statement")
564
 
565
+ if not all(os.path.exists(p) for p in [CLASSIFIER_PATH, SCORER_PATH, SCALER_PATH]):
566
+ st.warning("⚠️ Please train the model first (Tab 1)")
567
+ else:
 
 
 
 
 
568
  # Load models
569
  embedder, scaler, classifiers, scorers, thresholds = load_saved_models()
570
 
571
  if embedder is None:
572
  st.error("Failed to load models. Please retrain.")
 
 
 
 
 
 
 
 
 
 
 
 
 
573
  else:
574
+ # Input method
575
+ input_method = st.radio("Choose input method:", ["Paste Text", "Upload File"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
576
 
577
+ text_to_analyze = None
578
 
579
+ if input_method == "Paste Text":
580
+ text_to_analyze = st.text_area(
581
+ "Paste your personal statement here:",
582
+ height=300,
583
+ placeholder="Enter your personal statement..."
584
+ )
585
+ else:
586
+ uploaded_file = st.file_uploader("Upload statement (.txt)", type=['txt'])
587
+ if uploaded_file:
588
+ text_to_analyze = str(uploaded_file.read(), 'utf-8')
589
+ st.success("File uploaded successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
590
 
591
+ if text_to_analyze and st.button("Analyze Statement", type="primary"):
592
+ with st.spinner("Analyzing..."):
593
+ segment_results, category_results = analyze_statement(
594
+ text_to_analyze, embedder, scaler, classifiers, scorers, thresholds
595
+ )
596
+
597
+ # Display results
598
+ st.success("βœ“ Analysis complete!")
599
+
600
+ # Summary
601
+ st.subheader("πŸ“Š Overall Summary")
602
+ cols = st.columns(4)
603
+
604
+ detected = [cat for cat, res in category_results.items() if res['detected']]
605
+
606
+ with cols[0]:
607
+ st.metric("Categories Found", f"{len(detected)}/4")
608
+ with cols[1]:
609
+ if detected:
610
+ avg_score = np.mean([category_results[cat]['score'] for cat in detected])
611
+ st.metric("Average Score", f"{avg_score:.1f}/4")
612
+ else:
613
+ st.metric("Average Score", "N/A")
614
+ with cols[2]:
615
+ st.metric("Total Segments", len(segment_results))
616
+ with cols[3]:
617
+ quality = "Excellent" if len(detected) == 4 and avg_score >= 3.5 else "Good" if len(detected) >= 3 else "Needs Work"
618
+ st.metric("Overall", quality)
619
+
620
+ # Category breakdown
621
+ st.subheader("πŸ“‹ Category Analysis")
622
+ for cat in CATEGORIES.keys():
623
+ res = category_results[cat]
624
+ if res['detected']:
625
+ icon = "βœ…" if res['score'] >= 3 else "⚠️" if res['score'] >= 2 else "❌"
626
+ st.write(f"{icon} **{cat}**: Score {res['score']}/4 (Confidence: {res['confidence']:.1%})")
627
+ else:
628
+ st.write(f"❌ **{cat}**: Not detected")
629
+
630
+ # Segment details
631
+ st.subheader("πŸ“ Segment Details")
632
+ for seg in segment_results:
633
+ with st.expander(f"Segment {seg['segment_num']}: {seg['category']}"):
634
+ st.write(f"**Score:** {seg['score']}/4" if seg['score'] else "N/A")
635
+ st.write(f"**Confidence:** {seg['confidence']:.1%}")
636
+ st.write(f"**Text:** {seg['text'][:300]}...")
637
+
638
+ # Recommendations
639
+ st.subheader("πŸ’‘ Recommendations")
640
+ missing = [cat for cat, res in category_results.items() if not res['detected']]
641
+ low_score = [cat for cat, res in category_results.items()
642
+ if res['detected'] and res['score'] and res['score'] < 3]
643
+
644
+ if missing:
645
+ st.warning("**Missing Categories:**")
646
+ for cat in missing:
647
+ st.write(f"β€’ Add content for **{cat}**: {CATEGORIES[cat]['description']}")
648
+
649
+ if low_score:
650
+ st.info("**Areas to Improve:**")
651
+ for cat in low_score:
652
+ st.write(f"β€’ Strengthen **{cat}** (current score: {category_results[cat]['score']}/4)")
653
+
654
+ if not missing and not low_score:
655
+ st.success("Excellent work! All categories present with good scores.")
656
 
657
+ # View Rubrics Tab
658
+ with tab3:
659
+ st.header("Scoring Rubrics")
660
+
661
+ for category, info in CATEGORIES.items():
662
+ with st.expander(f"**{category}**"):
663
+ st.write(f"**Description:** {info['description']}")
664
+ st.write("**Scoring Criteria:**")
665
+ for score in [4, 3, 2, 1]:
666
+ st.write(f"β€’ **Score {score}:** {info['rubric'][score]}")
667
+ st.write(f"**Key Terms:** {', '.join(info['keywords'][:8])}")