anfastech commited on
Commit
e765887
·
1 Parent(s): 3c951bc

Fix: Gradio UI (default landing page) #2

Browse files
Files changed (1) hide show
  1. app.py +4 -35
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # app.py
2
  import logging
3
  import os
4
  import sys
@@ -19,7 +18,7 @@ logger = logging.getLogger(__name__)
19
  # Add project root to path
20
  sys.path.insert(0, str(Path(__file__).parent))
21
 
22
- # Import detector using model loader (clean architecture)
23
  try:
24
  from diagnosis.ai_engine.model_loader import get_stutter_detector
25
  logger.info("✅ Successfully imported model loader")
@@ -59,9 +58,7 @@ async def startup_event():
59
  raise
60
 
61
  def gradio_analyze(audio_path, transcript=""):
62
- """
63
- Analyze audio for stuttering using Gradio interface
64
- """
65
  if not detector:
66
  return {"error": "Models not loaded yet. Please try again later."}
67
  try:
@@ -82,7 +79,7 @@ gradio_app = gr.Interface(
82
  description="Upload an audio file and optionally provide a transcript to analyze for stuttering."
83
  )
84
 
85
- # Mount Gradio app to FastAPI
86
  gr.mount_gradio_app(app, gradio_app, path="/")
87
 
88
  @app.get("/health")
@@ -133,7 +130,6 @@ async def analyze_audio(
133
  logger.info(f"🔄 Analyzing audio with transcript: '{transcript[:50] if transcript else '(empty)'}...'")
134
  result = detector.analyze_audio(temp_file, transcript)
135
 
136
- # Log transcript values from result
137
  actual = result.get('actual_transcript', '')
138
  target = result.get('target_transcript', '')
139
  logger.info(f"✅ Analysis complete: severity={result['severity']}, mismatch={result['mismatch_percentage']}%")
@@ -155,33 +151,6 @@ async def analyze_audio(
155
  except Exception as e:
156
  logger.warning(f"Could not clean up {temp_file}: {e}")
157
 
158
- @app.get("/api")
159
- async def root():
160
- """API documentation"""
161
- return {
162
- "name": "SLAQ Stutter Detector API",
163
- "version": "1.0.0",
164
- "status": "running",
165
- "endpoints": {
166
- "health": "GET /health",
167
- "analyze": "POST /analyze (multipart: audio file + optional transcript field)",
168
- "docs": "GET /docs (interactive API docs)",
169
- "gradio": "GET /gradio (web UI for stutter detection)"
170
- },
171
- "models": {
172
- "base": "facebook/wav2vec2-base-960h",
173
- "large": "facebook/wav2vec2-large-960h-lv60-self",
174
- "xlsr": "jonatasgrosman/wav2vec2-large-xlsr-53-english"
175
- }
176
- }
177
-
178
- @app.get("/")
179
- async def root():
180
- """Redirect to Gradio UI - handled by gr.mount_gradio_app"""
181
- # This will be overridden by Gradio mounting at "/"
182
- # Users will see the Gradio UI when visiting "/"
183
- pass
184
-
185
  if __name__ == "__main__":
186
  import uvicorn
187
  logger.info("🚀 Starting SLAQ Stutter Detector API...")
@@ -190,4 +159,4 @@ if __name__ == "__main__":
190
  host="0.0.0.0",
191
  port=7860,
192
  log_level="info"
193
- )
 
 
1
  import logging
2
  import os
3
  import sys
 
18
  # Add project root to path
19
  sys.path.insert(0, str(Path(__file__).parent))
20
 
21
+ # Import detector using model loader
22
  try:
23
  from diagnosis.ai_engine.model_loader import get_stutter_detector
24
  logger.info("✅ Successfully imported model loader")
 
58
  raise
59
 
60
  def gradio_analyze(audio_path, transcript=""):
61
+ """Analyze audio for stuttering using Gradio interface"""
 
 
62
  if not detector:
63
  return {"error": "Models not loaded yet. Please try again later."}
64
  try:
 
79
  description="Upload an audio file and optionally provide a transcript to analyze for stuttering."
80
  )
81
 
82
+ # Mount Gradio app to FastAPI at root path
83
  gr.mount_gradio_app(app, gradio_app, path="/")
84
 
85
  @app.get("/health")
 
130
  logger.info(f"🔄 Analyzing audio with transcript: '{transcript[:50] if transcript else '(empty)'}...'")
131
  result = detector.analyze_audio(temp_file, transcript)
132
 
 
133
  actual = result.get('actual_transcript', '')
134
  target = result.get('target_transcript', '')
135
  logger.info(f"✅ Analysis complete: severity={result['severity']}, mismatch={result['mismatch_percentage']}%")
 
151
  except Exception as e:
152
  logger.warning(f"Could not clean up {temp_file}: {e}")
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  if __name__ == "__main__":
155
  import uvicorn
156
  logger.info("🚀 Starting SLAQ Stutter Detector API...")
 
159
  host="0.0.0.0",
160
  port=7860,
161
  log_level="info"
162
+ )