theNorms commited on
Commit
03e41ea
·
verified ·
1 Parent(s): 1e02ebb

Upload embodiment_pipeline.py

Browse files
Files changed (1) hide show
  1. models/embodiment_pipeline.py +104 -0
models/embodiment_pipeline.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ from dataclasses import dataclass
3
+ from typing import Dict, Any, Optional
4
+
5
+ try:
6
+ from TTS.api import TTS
7
+ except ImportError:
8
+ TTS = None
9
+
10
+ try:
11
+ import whisper
12
+ except ImportError:
13
+ whisper = None
14
+
15
+
16
+ @dataclass
17
+ class AvatarState:
18
+ facial_expression: str = "neutral"
19
+ gaze: str = "forward"
20
+ gesture: str = "idle"
21
+ posture: str = "balanced"
22
+
23
+
24
+ class EmbodimentSynchronizer:
25
+ def __init__(self, tts_model: str = "tts_models/en/ljspeech/tacotron2-DDC_ph"):
26
+ self.avatar_state = AvatarState()
27
+ self.tts = TTS(tts_model) if TTS is not None else None
28
+
29
+ def map_prosody(self, prosody: Dict[str, float]) -> AvatarState:
30
+ energy = prosody.get("energy", 0.5)
31
+ pitch = prosody.get("pitch", 0.5)
32
+ focus = prosody.get("focus", 0.5)
33
+
34
+ if energy > 0.8:
35
+ self.avatar_state.facial_expression = "excited"
36
+ self.avatar_state.gesture = "open_hands"
37
+ elif energy < 0.3:
38
+ self.avatar_state.facial_expression = "calm"
39
+ self.avatar_state.gesture = "hands_down"
40
+ else:
41
+ self.avatar_state.facial_expression = "attentive"
42
+ self.avatar_state.gesture = "subtle"
43
+
44
+ if pitch > 0.7:
45
+ self.avatar_state.gaze = "upward"
46
+ elif pitch < 0.3:
47
+ self.avatar_state.gaze = "downward"
48
+ else:
49
+ self.avatar_state.gaze = "forward"
50
+
51
+ if focus > 0.8:
52
+ self.avatar_state.posture = "lean_forward"
53
+ else:
54
+ self.avatar_state.posture = "balanced"
55
+
56
+ return self.avatar_state
57
+
58
+ def synthesize_audio(self, text: str, emotion_weight: float = 0.5) -> Optional[bytes]:
59
+ if self.tts is None:
60
+ raise RuntimeError("Coqui TTS is not installed. Install it with `pip install TTS`." )
61
+ wav = self.tts.tts(text=text, speaker="alloy", sample_rate=24000)
62
+ return wav
63
+
64
+ def synchronize(self, text: str, prosody: Dict[str, float], qualia_strength: float) -> Dict[str, Any]:
65
+ avatar = self.map_prosody(prosody)
66
+ audio = self.synthesize_audio(text=text, emotion_weight=qualia_strength)
67
+ return {
68
+ "avatar_state": avatar,
69
+ "audio": audio,
70
+ "text": text,
71
+ }
72
+
73
+
74
+ class StreamingVoicePipeline:
75
+ def __init__(self, stt_model_name: str = "base", tts_model: str = "tts_models/en/ljspeech/tacotron2-DDC_ph"):
76
+ self.whisper = whisper.load_model(stt_model_name) if whisper is not None else None
77
+ self.embodiment = EmbodimentSynchronizer(tts_model=tts_model)
78
+ self.vad_active = True
79
+
80
+ def transcribe_audio(self, audio_path: str) -> str:
81
+ if self.whisper is None:
82
+ raise RuntimeError("Whisper is not installed. Install it with `pip install openai-whisper`.")
83
+ result = self.whisper.transcribe(audio_path)
84
+ return result.get("text", "")
85
+
86
+ async def process_turn(self, audio_path: str, prosody: Dict[str, float], qualia_strength: float) -> Dict[str, Any]:
87
+ transcript = self.transcribe_audio(audio_path)
88
+ output_text = f"Processed: {transcript}"
89
+ return self.embodiment.synchronize(output_text, prosody, qualia_strength)
90
+
91
+
92
+ if __name__ == "__main__":
93
+ import argparse
94
+
95
+ parser = argparse.ArgumentParser(description="Embodiment and voice pipeline for Syntelligence.")
96
+ parser.add_argument("--audio", help="Path to input audio file.")
97
+ args = parser.parse_args()
98
+
99
+ if args.audio is None:
100
+ print("Provide --audio to process a voice turn.")
101
+ else:
102
+ pipeline = StreamingVoicePipeline()
103
+ result = asyncio.run(pipeline.process_turn(args.audio, {"energy": 0.7, "pitch": 0.5, "focus": 0.8}, qualia_strength=0.7))
104
+ print(result)