anfastech committed on
Commit
22c0a89
·
1 Parent(s): 46bb16f

Fix: remove the unused legacy methods (generate_target_transcript and related legacy helpers)

Browse files
Files changed (2) hide show
  1. app.py +1 -1
  2. diagnosis/ai_engine/detect_stuttering.py +3 -152
app.py CHANGED
@@ -155,7 +155,7 @@ async def analyze_audio(
155
  except Exception as e:
156
  logger.warning(f"Could not clean up {temp_file}: {e}")
157
 
158
- @app.get("/")
159
  async def root():
160
  """API documentation"""
161
  return {
 
155
  except Exception as e:
156
  logger.warning(f"Could not clean up {temp_file}: {e}")
157
 
158
+ @app.get("/api")
159
  async def root():
160
  """API documentation"""
161
  return {
diagnosis/ai_engine/detect_stuttering.py CHANGED
@@ -670,155 +670,6 @@ class AdvancedStutterDetector:
670
  }
671
 
672
 
673
- # Legacy methods - kept for backward compatibility but may not work without additional model initialization
674
- # These methods reference models (xlsr, base, large) that are not initialized in __init__
675
- # The main analyze_audio() method uses the IndicWav2Vec Hindi model instead
676
-
677
- def generate_target_transcript(self, audio_file: str) -> str:
678
- """Generate expected transcript - Legacy method (uses IndicWav2Vec Hindi model)"""
679
- try:
680
- audio, sr = librosa.load(audio_file, sr=16000)
681
- transcript, _, _ = self._transcribe_with_timestamps(audio)
682
- return transcript
683
- except Exception as e:
684
- logger.error(f"Target transcript generation failed: {e}")
685
- return ""
686
-
687
- def transcribe_and_detect(self, audio_file: str, proper_transcript: str) -> Dict:
688
- """Transcribe audio and detect stuttering patterns - Legacy method"""
689
- try:
690
- audio, _ = librosa.load(audio_file, sr=16000)
691
- transcript, _, _ = self._transcribe_with_timestamps(audio)
692
-
693
- # Find stuttered sequences
694
- stuttered_chars = self.find_sequences_not_in_common(transcript, proper_transcript)
695
-
696
- # Calculate mismatch percentage
697
- total_mismatched = sum(len(segment) for segment in stuttered_chars)
698
- mismatch_percentage = (total_mismatched / len(proper_transcript)) * 100 if len(proper_transcript) > 0 else 0
699
- mismatch_percentage = min(round(mismatch_percentage), 100)
700
-
701
- return {
702
- 'transcription': transcript,
703
- 'stuttered_chars': stuttered_chars,
704
- 'mismatch_percentage': mismatch_percentage
705
- }
706
- except Exception as e:
707
- logger.error(f"Transcription failed: {e}")
708
- return {
709
- 'transcription': '',
710
- 'stuttered_chars': [],
711
- 'mismatch_percentage': 0
712
- }
713
-
714
- def calculate_stutter_timestamps(self, audio_file: str, proper_transcript: str) -> Tuple[float, List[Tuple[float, float]]]:
715
- """Calculate stutter timestamps - Legacy method (uses analyze_audio instead)"""
716
- try:
717
- # Use main analyze_audio method
718
- result = self.analyze_audio(audio_file, proper_transcript)
719
-
720
- # Extract timestamps from result
721
- timestamps = []
722
- for event in result.get('stutter_timestamps', []):
723
- timestamps.append((event['start'], event['end']))
724
-
725
- ctc_score = result.get('ctc_loss_score', 0.0)
726
- return float(ctc_score), timestamps
727
- except Exception as e:
728
- logger.error(f"Timestamp calculation failed: {e}")
729
- return 0.0, []
730
-
731
-
732
- def find_max_common_characters(self, transcription1: str, transcript2: str) -> str:
733
- """Longest Common Subsequence algorithm"""
734
- m, n = len(transcription1), len(transcript2)
735
- lcs_matrix = [[0] * (n + 1) for _ in range(m + 1)]
736
-
737
- for i in range(1, m + 1):
738
- for j in range(1, n + 1):
739
- if transcription1[i - 1] == transcript2[j - 1]:
740
- lcs_matrix[i][j] = lcs_matrix[i - 1][j - 1] + 1
741
- else:
742
- lcs_matrix[i][j] = max(lcs_matrix[i - 1][j], lcs_matrix[i][j - 1])
743
-
744
- # Backtrack to find LCS
745
- lcs_characters = []
746
- i, j = m, n
747
- while i > 0 and j > 0:
748
- if transcription1[i - 1] == transcript2[j - 1]:
749
- lcs_characters.append(transcription1[i - 1])
750
- i -= 1
751
- j -= 1
752
- elif lcs_matrix[i - 1][j] > lcs_matrix[i][j - 1]:
753
- i -= 1
754
- else:
755
- j -= 1
756
-
757
- lcs_characters.reverse()
758
- return ''.join(lcs_characters)
759
-
760
-
761
- def find_sequences_not_in_common(self, transcription1: str, proper_transcript: str) -> List[str]:
762
- """Find stuttered character sequences"""
763
- common_characters = self.find_max_common_characters(transcription1, proper_transcript)
764
- sequences = []
765
- sequence = ""
766
- i, j = 0, 0
767
-
768
- while i < len(transcription1) and j < len(common_characters):
769
- if transcription1[i] == common_characters[j]:
770
- if sequence:
771
- sequences.append(sequence)
772
- sequence = ""
773
- i += 1
774
- j += 1
775
- else:
776
- sequence += transcription1[i]
777
- i += 1
778
-
779
- if sequence:
780
- sequences.append(sequence)
781
-
782
- return sequences
783
-
784
-
785
- def _calculate_total_duration(self, timestamps: List[Tuple[float, float]]) -> float:
786
- """Calculate total stuttering duration"""
787
- return sum(end - start for start, end in timestamps)
788
-
789
-
790
- def _calculate_frequency(self, timestamps: List[Tuple[float, float]], audio_file: str) -> float:
791
- """Calculate stutters per minute"""
792
- try:
793
- audio_duration = librosa.get_duration(path=audio_file)
794
- if audio_duration > 0:
795
- return (len(timestamps) / audio_duration) * 60
796
- return 0.0
797
- except:
798
- return 0.0
799
-
800
-
801
- def _determine_severity(self, mismatch_percentage: float) -> str:
802
- """Determine severity level"""
803
- if mismatch_percentage < 10:
804
- return 'none'
805
- elif mismatch_percentage < 25:
806
- return 'mild'
807
- elif mismatch_percentage < 50:
808
- return 'moderate'
809
- else:
810
- return 'severe'
811
-
812
-
813
- def _calculate_confidence(self, transcription_result: Dict, ctc_loss: float) -> float:
814
- """Calculate confidence score for the analysis"""
815
- # Lower mismatch and lower CTC loss = higher confidence
816
- mismatch_factor = 1 - (transcription_result['mismatch_percentage'] / 100)
817
- loss_factor = max(0, 1 - (ctc_loss / 10)) # Normalize loss
818
- confidence = (mismatch_factor + loss_factor) / 2
819
- return round(min(max(confidence, 0.0), 1.0), 2)
820
-
821
-
822
- # Model loader is now in a separate module: model_loader.py
823
- # This follows clean architecture principles - separation of concerns
824
- # Import using: from diagnosis.ai_engine.model_loader import get_stutter_detector
 
670
  }
671
 
672
 
673
+ # Model loader is now in a separate module: model_loader.py
674
+ # This follows clean architecture principles - separation of concerns
675
+ # Import using: from diagnosis.ai_engine.model_loader import get_stutter_detector