Akis Giannoukos committed on
Commit
d517324
·
1 Parent(s): 497441d

Added GPU decorator

Browse files
Files changed (2) hide show
  1. app.py +60 -50
  2. requirements.txt +1 -0
app.py CHANGED
@@ -19,6 +19,7 @@ from transformers import (
19
  pipeline,
20
  )
21
  from gtts import gTTS
 
22
 
23
 
24
  # ---------------------------
@@ -39,13 +40,17 @@ _gen_pipe = None
39
  _tokenizer = None
40
 
41
 
 
 
 
 
42
  def get_asr_pipeline():
43
  global _asr_pipe
44
  if _asr_pipe is None:
45
  _asr_pipe = pipeline(
46
  "automatic-speech-recognition",
47
  model=DEFAULT_ASR_MODEL_ID,
48
- device=-1,
49
  )
50
  return _asr_pipe
51
 
@@ -58,8 +63,8 @@ def get_textgen_pipeline():
58
  task="text-generation",
59
  model=DEFAULT_CHAT_MODEL_ID,
60
  tokenizer=DEFAULT_CHAT_MODEL_ID,
61
- device=-1,
62
- torch_dtype=torch.float32,
63
  )
64
  return _gen_pipe
65
 
@@ -334,6 +339,7 @@ def init_state() -> Tuple[List[Tuple[str, str]], Dict[str, Any], Dict[str, Any],
334
  return chat_history, scores, meta, finished, turns
335
 
336
 
 
337
  def process_turn(
338
  audio_path: Optional[str],
339
  text_input: Optional[str],
@@ -454,57 +460,61 @@ def reset_app():
454
  # ---------------------------
455
  # UI
456
  # ---------------------------
457
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
458
- gr.Markdown(
459
- """
460
- ### PHQ-9 Conversational Clinician Agent
461
- Engage in a brief, empathetic conversation. Your audio is transcribed, analyzed, and used to infer PHQ-9 scores.
462
- The system stops when confidence is high enough or any safety risk is detected. It does not provide therapy or emergency counseling.
463
- """
464
- )
 
 
465
 
466
- with gr.Row():
467
- chatbot = gr.Chatbot(height=400, type="tuples")
468
- with gr.Column():
469
- score_json = gr.JSON(label="PHQ-9 Assessment (live)")
470
- severity_label = gr.Label(label="Severity")
471
- threshold = gr.Slider(0.5, 1.0, value=CONFIDENCE_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold (stop when min ≥ τ)")
472
- tts_enable = gr.Checkbox(label="Speak clinician responses (TTS)", value=USE_TTS_DEFAULT)
473
- tts_audio = gr.Audio(label="Clinician voice", interactive=False)
474
-
475
- with gr.Row():
476
- audio = gr.Audio(sources=["microphone"], type="filepath", label="Speak your response (or use text)")
477
- text = gr.Textbox(lines=2, placeholder="Optional: type your response instead of audio")
478
-
479
- with gr.Row():
480
- send_btn = gr.Button("Send")
481
- reset_btn = gr.Button("Reset")
482
-
483
- # App state
484
- chat_state = gr.State()
485
- scores_state = gr.State()
486
- meta_state = gr.State()
487
- finished_state = gr.State()
488
- turns_state = gr.State()
489
-
490
- # Initialize on load
491
- def _on_load():
492
- return init_state()
493
-
494
- demo.load(_on_load, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])
495
-
496
- # Wire interactions
497
- send_btn.click(
498
- fn=process_turn,
499
- inputs=[audio, text, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state],
500
- outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio, text, tts_audio],
501
- queue=True,
502
- api_name="message",
503
- )
504
 
505
- reset_btn.click(fn=reset_app, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])
506
 
 
507
 
 
508
  if __name__ == "__main__":
509
  # For local dev
510
  demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
 
19
  pipeline,
20
  )
21
  from gtts import gTTS
22
+ import spaces
23
 
24
 
25
  # ---------------------------
 
40
  _tokenizer = None
41
 
42
 
43
+ def _hf_device() -> int:
44
+ return 0 if torch.cuda.is_available() else -1
45
+
46
+
47
  def get_asr_pipeline():
48
  global _asr_pipe
49
  if _asr_pipe is None:
50
  _asr_pipe = pipeline(
51
  "automatic-speech-recognition",
52
  model=DEFAULT_ASR_MODEL_ID,
53
+ device=_hf_device(),
54
  )
55
  return _asr_pipe
56
 
 
63
  task="text-generation",
64
  model=DEFAULT_CHAT_MODEL_ID,
65
  tokenizer=DEFAULT_CHAT_MODEL_ID,
66
+ device=_hf_device(),
67
+ torch_dtype=(torch.float16 if torch.cuda.is_available() else torch.float32),
68
  )
69
  return _gen_pipe
70
 
 
339
  return chat_history, scores, meta, finished, turns
340
 
341
 
342
+ @spaces.GPU
343
  def process_turn(
344
  audio_path: Optional[str],
345
  text_input: Optional[str],
 
460
  # ---------------------------
461
  # UI
462
  # ---------------------------
463
+ @spaces.GPU
464
+ def create_demo():
465
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
466
+ gr.Markdown(
467
+ """
468
+ ### PHQ-9 Conversational Clinician Agent
469
+ Engage in a brief, empathetic conversation. Your audio is transcribed, analyzed, and used to infer PHQ-9 scores.
470
+ The system stops when confidence is high enough or any safety risk is detected. It does not provide therapy or emergency counseling.
471
+ """
472
+ )
473
 
474
+ with gr.Row():
475
+ chatbot = gr.Chatbot(height=400, type="tuples")
476
+ with gr.Column():
477
+ score_json = gr.JSON(label="PHQ-9 Assessment (live)")
478
+ severity_label = gr.Label(label="Severity")
479
+ threshold = gr.Slider(0.5, 1.0, value=CONFIDENCE_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold (stop when min ≥ τ)")
480
+ tts_enable = gr.Checkbox(label="Speak clinician responses (TTS)", value=USE_TTS_DEFAULT)
481
+ tts_audio = gr.Audio(label="Clinician voice", interactive=False)
482
+
483
+ with gr.Row():
484
+ audio = gr.Audio(sources=["microphone"], type="filepath", label="Speak your response (or use text)")
485
+ text = gr.Textbox(lines=2, placeholder="Optional: type your response instead of audio")
486
+
487
+ with gr.Row():
488
+ send_btn = gr.Button("Send")
489
+ reset_btn = gr.Button("Reset")
490
+
491
+ # App state
492
+ chat_state = gr.State()
493
+ scores_state = gr.State()
494
+ meta_state = gr.State()
495
+ finished_state = gr.State()
496
+ turns_state = gr.State()
497
+
498
+ # Initialize on load
499
+ def _on_load():
500
+ return init_state()
501
+
502
+ demo.load(_on_load, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])
503
+
504
+ # Wire interactions
505
+ send_btn.click(
506
+ fn=process_turn,
507
+ inputs=[audio, text, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state],
508
+ outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio, text, tts_audio],
509
+ queue=True,
510
+ api_name="message",
511
+ )
512
 
513
+ reset_btn.click(fn=reset_app, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])
514
 
515
+ return demo
516
 
517
+ demo = create_demo()
518
  if __name__ == "__main__":
519
  # For local dev
520
  demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
requirements.txt CHANGED
@@ -9,4 +9,5 @@ numpy>=1.26.4
9
  scipy>=1.11.4
10
  protobuf>=4.25.3
11
  gTTS>=2.5.3
 
12
 
 
9
  scipy>=1.11.4
10
  protobuf>=4.25.3
11
  gTTS>=2.5.3
12
+ spaces>=0.27.1
13