Spaces:
Sleeping
Sleeping
Akis Giannoukos committed on
Commit ·
d517324
1
Parent(s): 497441d
Added GPU decorator
Browse files- app.py +60 -50
- requirements.txt +1 -0
app.py
CHANGED
|
@@ -19,6 +19,7 @@ from transformers import (
|
|
| 19 |
pipeline,
|
| 20 |
)
|
| 21 |
from gtts import gTTS
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
# ---------------------------
|
|
@@ -39,13 +40,17 @@ _gen_pipe = None
|
|
| 39 |
_tokenizer = None
|
| 40 |
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
def get_asr_pipeline():
|
| 43 |
global _asr_pipe
|
| 44 |
if _asr_pipe is None:
|
| 45 |
_asr_pipe = pipeline(
|
| 46 |
"automatic-speech-recognition",
|
| 47 |
model=DEFAULT_ASR_MODEL_ID,
|
| 48 |
-
device=
|
| 49 |
)
|
| 50 |
return _asr_pipe
|
| 51 |
|
|
@@ -58,8 +63,8 @@ def get_textgen_pipeline():
|
|
| 58 |
task="text-generation",
|
| 59 |
model=DEFAULT_CHAT_MODEL_ID,
|
| 60 |
tokenizer=DEFAULT_CHAT_MODEL_ID,
|
| 61 |
-
device=
|
| 62 |
-
torch_dtype=torch.float32,
|
| 63 |
)
|
| 64 |
return _gen_pipe
|
| 65 |
|
|
@@ -334,6 +339,7 @@ def init_state() -> Tuple[List[Tuple[str, str]], Dict[str, Any], Dict[str, Any],
|
|
| 334 |
return chat_history, scores, meta, finished, turns
|
| 335 |
|
| 336 |
|
|
|
|
| 337 |
def process_turn(
|
| 338 |
audio_path: Optional[str],
|
| 339 |
text_input: Optional[str],
|
|
@@ -454,57 +460,61 @@ def reset_app():
|
|
| 454 |
# ---------------------------
|
| 455 |
# UI
|
| 456 |
# ---------------------------
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
|
|
|
|
|
|
| 465 |
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
| 483 |
-
|
| 484 |
-
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
|
| 505 |
-
|
| 506 |
|
|
|
|
| 507 |
|
|
|
|
| 508 |
if __name__ == "__main__":
|
| 509 |
# For local dev
|
| 510 |
demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
|
|
|
|
| 19 |
pipeline,
|
| 20 |
)
|
| 21 |
from gtts import gTTS
|
| 22 |
+
import spaces
|
| 23 |
|
| 24 |
|
| 25 |
# ---------------------------
|
|
|
|
| 40 |
_tokenizer = None
|
| 41 |
|
| 42 |
|
| 43 |
+
def _hf_device() -> int:
|
| 44 |
+
return 0 if torch.cuda.is_available() else -1
|
| 45 |
+
|
| 46 |
+
|
| 47 |
def get_asr_pipeline():
|
| 48 |
global _asr_pipe
|
| 49 |
if _asr_pipe is None:
|
| 50 |
_asr_pipe = pipeline(
|
| 51 |
"automatic-speech-recognition",
|
| 52 |
model=DEFAULT_ASR_MODEL_ID,
|
| 53 |
+
device=_hf_device(),
|
| 54 |
)
|
| 55 |
return _asr_pipe
|
| 56 |
|
|
|
|
| 63 |
task="text-generation",
|
| 64 |
model=DEFAULT_CHAT_MODEL_ID,
|
| 65 |
tokenizer=DEFAULT_CHAT_MODEL_ID,
|
| 66 |
+
device=_hf_device(),
|
| 67 |
+
torch_dtype=(torch.float16 if torch.cuda.is_available() else torch.float32),
|
| 68 |
)
|
| 69 |
return _gen_pipe
|
| 70 |
|
|
|
|
| 339 |
return chat_history, scores, meta, finished, turns
|
| 340 |
|
| 341 |
|
| 342 |
+
@spaces.GPU
|
| 343 |
def process_turn(
|
| 344 |
audio_path: Optional[str],
|
| 345 |
text_input: Optional[str],
|
|
|
|
| 460 |
# ---------------------------
|
| 461 |
# UI
|
| 462 |
# ---------------------------
|
| 463 |
+
@spaces.GPU
|
| 464 |
+
def create_demo():
|
| 465 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 466 |
+
gr.Markdown(
|
| 467 |
+
"""
|
| 468 |
+
### PHQ-9 Conversational Clinician Agent
|
| 469 |
+
Engage in a brief, empathetic conversation. Your audio is transcribed, analyzed, and used to infer PHQ-9 scores.
|
| 470 |
+
The system stops when confidence is high enough or any safety risk is detected. It does not provide therapy or emergency counseling.
|
| 471 |
+
"""
|
| 472 |
+
)
|
| 473 |
|
| 474 |
+
with gr.Row():
|
| 475 |
+
chatbot = gr.Chatbot(height=400, type="tuples")
|
| 476 |
+
with gr.Column():
|
| 477 |
+
score_json = gr.JSON(label="PHQ-9 Assessment (live)")
|
| 478 |
+
severity_label = gr.Label(label="Severity")
|
| 479 |
+
threshold = gr.Slider(0.5, 1.0, value=CONFIDENCE_THRESHOLD_DEFAULT, step=0.05, label="Confidence Threshold (stop when min ≥ τ)")
|
| 480 |
+
tts_enable = gr.Checkbox(label="Speak clinician responses (TTS)", value=USE_TTS_DEFAULT)
|
| 481 |
+
tts_audio = gr.Audio(label="Clinician voice", interactive=False)
|
| 482 |
+
|
| 483 |
+
with gr.Row():
|
| 484 |
+
audio = gr.Audio(sources=["microphone"], type="filepath", label="Speak your response (or use text)")
|
| 485 |
+
text = gr.Textbox(lines=2, placeholder="Optional: type your response instead of audio")
|
| 486 |
+
|
| 487 |
+
with gr.Row():
|
| 488 |
+
send_btn = gr.Button("Send")
|
| 489 |
+
reset_btn = gr.Button("Reset")
|
| 490 |
+
|
| 491 |
+
# App state
|
| 492 |
+
chat_state = gr.State()
|
| 493 |
+
scores_state = gr.State()
|
| 494 |
+
meta_state = gr.State()
|
| 495 |
+
finished_state = gr.State()
|
| 496 |
+
turns_state = gr.State()
|
| 497 |
+
|
| 498 |
+
# Initialize on load
|
| 499 |
+
def _on_load():
|
| 500 |
+
return init_state()
|
| 501 |
+
|
| 502 |
+
demo.load(_on_load, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])
|
| 503 |
+
|
| 504 |
+
# Wire interactions
|
| 505 |
+
send_btn.click(
|
| 506 |
+
fn=process_turn,
|
| 507 |
+
inputs=[audio, text, chatbot, threshold, tts_enable, finished_state, turns_state, scores_state, meta_state],
|
| 508 |
+
outputs=[chatbot, score_json, severity_label, finished_state, turns_state, audio, text, tts_audio],
|
| 509 |
+
queue=True,
|
| 510 |
+
api_name="message",
|
| 511 |
+
)
|
| 512 |
|
| 513 |
+
reset_btn.click(fn=reset_app, inputs=None, outputs=[chatbot, scores_state, meta_state, finished_state, turns_state])
|
| 514 |
|
| 515 |
+
return demo
|
| 516 |
|
| 517 |
+
demo = create_demo()
|
| 518 |
if __name__ == "__main__":
|
| 519 |
# For local dev
|
| 520 |
demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=int(os.getenv("PORT", "7860")))
|
requirements.txt
CHANGED
|
@@ -9,4 +9,5 @@ numpy>=1.26.4
|
|
| 9 |
scipy>=1.11.4
|
| 10 |
protobuf>=4.25.3
|
| 11 |
gTTS>=2.5.3
|
|
|
|
| 12 |
|
|
|
|
| 9 |
scipy>=1.11.4
|
| 10 |
protobuf>=4.25.3
|
| 11 |
gTTS>=2.5.3
|
| 12 |
+
spaces>=0.27.1
|
| 13 |
|