jlamperez committed on
Commit 7e8db3f · 1 Parent(s): 6e29e7e
.kiro/settings/mcp.json ADDED
@@ -0,0 +1,7 @@
{
  "mcpServers": {
    "hf-mcp-server": {
      "url": "https://huggingface.co/mcp?login"
    }
  }
}
.kiro/specs/gemini-multimodal-refactor/design.md ADDED
@@ -0,0 +1,2309 @@
# Design Document

## Overview

This design document outlines the architecture for refactoring the Mortis interactive AI Halloween experience to support multi-modal (voice and text) interaction using Google Gemini API and SmolVLA-based robotic control. The refactor transforms Mortis from a simple gesture-based system into a sophisticated manipulation robot capable of executing precise tasks through natural language commands.

### Key Design Goals

1. Replace existing LLM API with Google Gemini API for conversational AI
2. Add voice input (STT) and voice output (TTS) capabilities
3. Integrate SmolVLA model for vision-language-action robotic control
4. Implement asynchronous execution to maintain UI responsiveness
5. Support both conversational gestures and precise manipulation tasks
6. Maintain backward compatibility with existing features
7. Enable local deployment with GPU support for SmolVLA inference

### System Context

The current Mortis system uses:
- Gradio web interface with chat and webcam view
- Generic LLM API with structured tool calling for gesture control
- LeRobot SO101Follower for predefined gesture sequences
- Synchronous execution model

The refactored system will add:
- Google Gemini API integration with intent detection
- Audio input/output components in Gradio
- SmolVLA model for learned manipulation behaviors
- Asynchronous task execution with message queuing
- Dataset collection and training infrastructure

## Architecture

### High-Level Architecture Diagram

```mermaid
graph TB
    subgraph "Gradio Web Interface"
        UI[User Interface]
        Audio[Audio Input/Output]
        Chat[Chat Interface]
        Video[Webcam View]
    end

    subgraph "Application Layer"
        STT[Speech-to-Text Service]
        TTS[Text-to-Speech Service]
        Gemini[Gemini API Client]
        IntentRouter[Intent Router]
    end

    subgraph "Execution Layer"
        Queue[Message Queue]
        GestureExec[Gesture Executor]
        SmolVLAExec[SmolVLA Executor]
    end

    subgraph "Robot Control"
        SO101[SO101 Follower Driver]
        SmolVLA[SmolVLA Model]
        Camera[Camera Feed]
    end

    subgraph "Training Infrastructure"
        DataCollect[Data Collection]
        Dataset[LeRobot Dataset]
        Training[Training Pipeline]
    end

    UI --> Audio
    UI --> Chat
    Audio --> STT
    Chat --> Gemini
    STT --> Gemini
    Gemini --> IntentRouter
    IntentRouter --> Queue
    Queue --> GestureExec
    Queue --> SmolVLAExec
    GestureExec --> SO101
    SmolVLAExec --> SmolVLA
    SmolVLA --> SO101
    Gemini --> TTS
    TTS --> Audio
    Camera --> Video
    Camera --> SmolVLA
    DataCollect --> Dataset
    Dataset --> Training
    Training --> SmolVLA
```

### Architecture Layers

#### 1. Presentation Layer (Gradio Interface)
- Handles user interaction through web browser
- Provides audio input component for voice recording
- Displays chat messages and system responses
- Shows webcam feed for visual monitoring
- Plays audio responses through browser

#### 2. Application Layer (Business Logic)
- Gemini API client for conversational AI
- STT service for voice-to-text conversion
- TTS service for text-to-voice conversion
- Intent router to distinguish between conversational and manipulation commands
- Response formatter for structured outputs

#### 3. Execution Layer (Asynchronous Processing)
- Message queue for decoupling UI from long-running operations
- Gesture executor for predefined movement sequences
- SmolVLA executor for learned manipulation tasks
- Status tracking and progress reporting

#### 4. Robot Control Layer (Hardware Interface)
- SO101Follower driver for low-level servo control
- SmolVLA model for vision-language-action inference
- Camera interface for visual observations
- Safety monitoring and error recovery

#### 5. Training Infrastructure (Offline)
- Data collection tools for recording demonstrations
- LeRobot dataset management
- Training pipeline for SmolVLA model
- Model evaluation and validation

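The wiring between these layers can be sketched as a small bootstrap routine. The module paths (`mortis.executor`, `mortis.gemini_client`, `mortis.smolvla_executor`) are illustrative assumptions; the classes are the ones specified in the component sections below.

```python
import os

# Hypothetical module layout; the classes come from the component sections below.
from mortis.executor import AsyncExecutor
from mortis.gemini_client import GeminiClient
from mortis.smolvla_executor import init_smolvla


def bootstrap():
    """Wire the application, execution, and robot-control layers at startup."""
    # Application layer: conversational AI client
    gemini = GeminiClient(api_key=os.environ["GEMINI_API_KEY"])

    # Execution layer: background worker consuming the task queue
    executor = AsyncExecutor()
    executor.start()

    # Robot control layer: load the trained SmolVLA policy once
    init_smolvla(os.getenv("SMOLVLA_CHECKPOINT_PATH", "checkpoints/smolvla_best.pt"))

    return gemini, executor
```
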
126
+ ## Components and Interfaces
127
+
128
+ ### 1. Gemini API Integration
129
+
130
+ #### Component: `GeminiClient`
131
+
132
+ **Purpose:** Manages all interactions with Google Gemini API for conversational AI and intent detection.
133
+
134
+ **Key Methods:**
135
+ - `send_message(user_input: str, conversation_history: list) -> GeminiResponse`
136
+ - `detect_intent(user_input: str) -> Intent`
137
+ - `configure_model(model_name: str, temperature: float)`
138
+
139
+ **Configuration:**
140
+ ```python
141
+ # Environment variables
142
+ GEMINI_API_KEY=your_google_api_key
143
+ GEMINI_MODEL=gemini-2.0-flash-exp # or gemini-1.5-pro
144
+ GEMINI_TEMPERATURE=0.2
145
+ ```
146
+
147
+ **System Prompt Design:**
148
+
149
+ The Gemini system prompt must accomplish two critical functions:
150
+
151
+ 1. **Character Maintenance:** Preserve Mortis personality (mischievous Halloween spirit)
152
+ 2. **Intent Detection:** Identify manipulation task commands vs. conversational input
153
+
154
+ ```python
155
+ GEMINI_SYSTEM_PROMPT = """
156
+ You are Mortis, a mischievous Halloween spirit inhabiting a robotic arm.
157
+
158
+ MANIPULATION TASKS:
159
+ You can perform these exact manipulation tasks:
160
+ - "Pick up the skull and place it in the green cup"
161
+ - "Pick up the skull and place it in the orange cup"
162
+ - "Pick up the skull and place it in the purple cup"
163
+ - "Pick up the eyeball and place it in the green cup"
164
+ - "Pick up the eyeball and place it in the orange cup"
165
+ - "Pick up the eyeball and place it in the purple cup"
166
+
167
+ RESPONSE FORMAT:
168
+ If user input matches a manipulation task (even with variations):
169
+ {
170
+ "type": "manipulation",
171
+ "command": "<exact_task_string>",
172
+ "message": "<short in-character response, <=30 words>",
173
+ "mood": "<ominous|playful|angry|nervous|triumphant|mischievous|sinister|curious|neutral>"
174
+ }
175
+
176
+ If user input is conversational:
177
+ {
178
+ "type": "conversation",
179
+ "message": "<short in-character response, <=30 words>",
180
+ "mood": "<mood>",
181
+ "gesture": "<idle|wave|point_left|point_right|grab|drop>"
182
+ }
183
+
184
+ Keep responses brief, in-character, no emojis or markdown.
185
+ """
186
+ ```
187
+
188
+ **Google SDK Usage:**
189
+
190
+ ```python
191
+ import google.generativeai as genai
192
+
193
+ genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
194
+ model = genai.GenerativeModel('gemini-2.0-flash-exp')
195
+
196
+ # For structured output, use JSON mode
197
+ generation_config = {
198
+ "temperature": 0.2,
199
+ "response_mime_type": "application/json"
200
+ }
201
+ ```
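To make `send_message` concrete, here is a minimal sketch under the configuration above. It assumes the `GEMINI_SYSTEM_PROMPT` defined earlier and the chat interface of the `google-generativeai` SDK; error handling is deferred to the Error Handling section.

```python
import json
import os

import google.generativeai as genai

genai.configure(api_key=os.getenv("GEMINI_API_KEY"))

# JSON mode plus the system prompt yields the structured responses shown above.
model = genai.GenerativeModel(
    os.getenv("GEMINI_MODEL", "gemini-2.0-flash-exp"),
    system_instruction=GEMINI_SYSTEM_PROMPT,
    generation_config={
        "temperature": float(os.getenv("GEMINI_TEMPERATURE", "0.2")),
        "response_mime_type": "application/json",
    },
)


def send_message(user_input: str, conversation_history: list) -> dict:
    """Send one user turn and return the parsed structured response."""
    chat = model.start_chat(history=conversation_history)
    reply = chat.send_message(user_input)
    return json.loads(reply.text)
```
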
### 2. Speech-to-Text (STT) Integration

#### Component: `STTService`

**Purpose:** Convert user voice input to text for processing by Gemini.

**Architecture Decision: Cloud vs. Local STT**

| Approach | Pros | Cons | Recommendation |
|----------|------|------|----------------|
| **Google Speech-to-Text API** | High accuracy, fast, supports streaming, integrates with Gemini ecosystem | Requires internet, API costs, data leaves local system | **Recommended for production** |
| **Local Whisper (Hugging Face)** | Privacy-preserving, no API costs, works offline | Slower inference, requires GPU/CPU resources, lower accuracy for accents | Good for offline/privacy scenarios |
| **Gemini Audio Input** | Single API integration, context-aware | Limited to Gemini models with audio support, less control | **Best option if available** |

**Recommended Implementation: Gemini Native Audio**

Gemini 2.0 models support native audio input, eliminating the need for a separate STT step:

```python
import google.generativeai as genai

# Upload audio file
audio_file = genai.upload_file(path="user_audio.wav")

# Send to Gemini with audio
response = model.generate_content([
    "Transcribe and respond to this audio as Mortis:",
    audio_file
])
```

**Fallback Implementation: Google Speech-to-Text**

```python
from google.cloud import speech_v1

def transcribe_audio(audio_bytes: bytes) -> str:
    client = speech_v1.SpeechClient()

    audio = speech_v1.RecognitionAudio(content=audio_bytes)
    config = speech_v1.RecognitionConfig(
        encoding=speech_v1.RecognitionConfig.AudioEncoding.LINEAR16,
        sample_rate_hertz=16000,
        language_code="en-US",
    )

    response = client.recognize(config=config, audio=audio)
    return response.results[0].alternatives[0].transcript
```

**Gradio Integration:**

```python
import gradio as gr

with gr.Blocks() as demo:
    audio_input = gr.Audio(
        sources=["microphone"],
        type="filepath",
        label="Speak to Mortis"
    )

    audio_input.change(
        fn=process_audio_input,
        inputs=[audio_input],
        outputs=[chatbot]
    )
```

### 3. Text-to-Speech (TTS) Integration

#### Component: `TTSService`

**Purpose:** Convert Gemini text responses to audio for voice output.

**Recommended Approach: Google Text-to-Speech API**

```python
from google.cloud import texttospeech

def synthesize_speech(text: str, output_path: str) -> str:
    client = texttospeech.TextToSpeechClient()

    synthesis_input = texttospeech.SynthesisInput(text=text)

    # Configure voice (creepy/ominous for Mortis)
    voice = texttospeech.VoiceSelectionParams(
        language_code="en-US",
        name="en-US-Neural2-D",  # Deep male voice
        ssml_gender=texttospeech.SsmlVoiceGender.MALE
    )

    audio_config = texttospeech.AudioConfig(
        audio_encoding=texttospeech.AudioEncoding.MP3,
        speaking_rate=0.9,  # Slightly slower for ominous effect
        pitch=-2.0  # Lower pitch for spooky voice
    )

    response = client.synthesize_speech(
        input=synthesis_input,
        voice=voice,
        audio_config=audio_config
    )

    with open(output_path, "wb") as out:
        out.write(response.audio_content)

    return output_path
```

**Alternative: Local TTS (pyttsx3 or gTTS)**

For lighter-weight setups (note that gTTS still requires internet access; pyttsx3 runs fully offline):

```python
from gtts import gTTS

def synthesize_speech_local(text: str, output_path: str) -> str:
    tts = gTTS(text=text, lang='en', slow=True)
    tts.save(output_path)
    return output_path
```

**Gradio Integration:**

```python
import time

import gradio as gr

def mortis_reply_with_voice(message, history, model_name):
    # Get text response from Gemini
    response_text, mood, action = process_with_gemini(message, model_name)

    # Generate audio
    audio_path = synthesize_speech(response_text, f"outputs/response_{time.time()}.mp3")

    return response_text, audio_path

with gr.Blocks() as demo:
    audio_output = gr.Audio(
        label="Mortis speaks",
        autoplay=True,
        type="filepath"
    )
```

### 4. Intent Router

#### Component: `IntentRouter`

**Purpose:** Parse Gemini responses and route them to the appropriate execution path.

**Design:**

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional

class IntentType(Enum):
    CONVERSATION = "conversation"
    MANIPULATION = "manipulation"

@dataclass
class Intent:
    type: IntentType
    message: str
    mood: str
    gesture: Optional[str] = None
    command: Optional[str] = None

class IntentRouter:
    def __init__(self):
        self.valid_commands = [
            "Pick up the skull and place it in the green cup",
            "Pick up the skull and place it in the orange cup",
            "Pick up the skull and place it in the purple cup",
            "Pick up the eyeball and place it in the green cup",
            "Pick up the eyeball and place it in the orange cup",
            "Pick up the eyeball and place it in the purple cup",
        ]

    def parse_gemini_response(self, response_json: dict) -> Intent:
        """Parse structured JSON response from Gemini."""
        intent_type = IntentType(response_json.get("type", "conversation"))

        if intent_type == IntentType.MANIPULATION:
            return Intent(
                type=IntentType.MANIPULATION,
                message=response_json["message"],
                mood=response_json["mood"],
                command=response_json["command"]
            )
        else:
            return Intent(
                type=IntentType.CONVERSATION,
                message=response_json["message"],
                mood=response_json["mood"],
                gesture=response_json.get("gesture", "idle")
            )

    def validate_command(self, command: str) -> bool:
        """Verify command is in the trained task set."""
        return command in self.valid_commands
```

**Execution Flow:**

```python
def process_user_input(user_input: str, model_name: str):
    # 1. Send to Gemini
    gemini_response = gemini_client.send_message(user_input)

    # 2. Parse intent
    intent = intent_router.parse_gemini_response(gemini_response)

    # 3. Route to the appropriate executor
    if intent.type == IntentType.MANIPULATION:
        if intent_router.validate_command(intent.command):
            # Queue for async SmolVLA execution
            task_queue.put({
                "type": "manipulation",
                "command": intent.command,
                "message": intent.message
            })
        else:
            # Invalid command, treat as conversation
            execute_gesture(intent.gesture or "idle")
    else:
        # Execute gesture immediately
        execute_gesture(intent.gesture)

    # 4. Generate voice response
    audio_path = tts_service.synthesize(intent.message)

    return intent.message, audio_path
```

### 5. Asynchronous Execution System

#### Component: `AsyncExecutor`

**Purpose:** Decouple long-running SmolVLA inference from the Gradio UI to maintain responsiveness.

**Architecture Decision: Message Queue vs. Background Processing**

| Approach | Pros | Cons | Recommendation |
|----------|------|------|----------------|
| **Redis Queue** | Robust, scalable, persistent, supports distributed workers | External dependency, overkill for a single machine | Good for production/multi-worker |
| **Python asyncio.Queue** | Built-in, simple, no dependencies | Single process only, not persistent | **Recommended for this use case** |
| **multiprocessing.Queue** | True parallelism, GPU isolation | Complex IPC, harder debugging | Good if GPU contention is an issue |
| **Threading + Queue** | Simple, shared memory | GIL limitations, not ideal for CPU-bound work | Not recommended for ML inference |

**Recommended Implementation: Background Worker Thread with Thread-Safe Queues**

```python
from queue import Empty, Queue
from threading import Thread

class AsyncExecutor:
    def __init__(self):
        self.task_queue = Queue()
        self.status_queue = Queue()
        self.worker_thread = None
        self.running = False

    def start(self):
        """Start background worker thread."""
        self.running = True
        self.worker_thread = Thread(target=self._worker_loop, daemon=True)
        self.worker_thread.start()

    def stop(self):
        """Stop background worker."""
        self.running = False
        if self.worker_thread:
            self.worker_thread.join(timeout=5)

    def _worker_loop(self):
        """Background thread that processes tasks."""
        while self.running:
            try:
                task = self.task_queue.get(timeout=1)
            except Empty:
                continue
            self._execute_task(task)

    def _execute_task(self, task):
        """Execute a single task."""
        try:
            if task["type"] == "manipulation":
                self.status_queue.put({"status": "running", "task": task["command"]})

                # Execute SmolVLA inference (blocking)
                smolvla_executor.execute(task["command"])

                self.status_queue.put({"status": "complete", "task": task["command"]})
            elif task["type"] == "gesture":
                mortis_arm.move_arm(task["gesture"])
                self.status_queue.put({"status": "complete", "task": task["gesture"]})
        except Exception as e:
            self.status_queue.put({"status": "error", "error": str(e)})

    def submit_task(self, task: dict):
        """Submit task for async execution."""
        self.task_queue.put(task)

    def get_status(self) -> dict:
        """Get latest status update (non-blocking)."""
        try:
            return self.status_queue.get_nowait()
        except Empty:
            return None

# Global executor instance
async_executor = AsyncExecutor()
```

**Gradio Integration with Status Updates:**

```python
import gradio as gr

def mortis_reply(message, history, model_name):
    # Process with Gemini
    intent = process_with_gemini(message, model_name)

    # Submit task asynchronously
    if intent.type == IntentType.MANIPULATION:
        async_executor.submit_task({
            "type": "manipulation",
            "command": intent.command
        })
        status_msg = f"🤖 Executing: {intent.command}..."
    else:
        async_executor.submit_task({
            "type": "gesture",
            "gesture": intent.gesture
        })
        status_msg = f"👻 {intent.gesture}"

    # Generate audio response
    audio_path = tts_service.synthesize(intent.message)

    return intent.message, audio_path, status_msg

def check_status():
    """Periodic status checker for Gradio."""
    status = async_executor.get_status()
    if status:
        if status["status"] == "complete":
            return f"✅ Completed: {status['task']}"
        elif status["status"] == "running":
            return f"⏳ Running: {status['task']}"
        elif status["status"] == "error":
            return f"❌ Error: {status['error']}"
    return "Idle"

with gr.Blocks() as demo:
    status_display = gr.Textbox(label="Robot Status", value="Idle")

    # Update status every 500 ms
    demo.load(
        fn=check_status,
        outputs=[status_display],
        every=0.5
    )
```

### 6. SmolVLA Model Integration

#### Component: `SmolVLAExecutor`

**Purpose:** Execute vision-language-action inference for manipulation tasks.

**LeRobot SmolVLA Overview:**

SmolVLA is a vision-language-action model that:
- Takes visual observations (camera images) as input
- Accepts natural language task descriptions
- Outputs robot actions (joint positions/velocities)
- Is trained end-to-end on demonstration data

**Model Architecture:**

```python
import torch
from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy

class SmolVLAExecutor:
    def __init__(self, checkpoint_path: str, device: str = "cuda"):
        self.device = device
        self.policy = self._load_model(checkpoint_path)
        self.camera = self._init_camera()

    def _load_model(self, checkpoint_path: str) -> SmolVLAPolicy:
        """Load trained SmolVLA model from checkpoint."""
        config = SmolVLAConfig.from_pretrained(checkpoint_path)
        policy = SmolVLAPolicy.from_pretrained(
            checkpoint_path,
            config=config
        )
        policy.to(self.device)
        policy.eval()
        return policy

    def _init_camera(self):
        """Initialize camera for visual observations."""
        from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera
        camera = OpenCVCamera(camera_index=0, fps=30, width=640, height=480)
        camera.connect()
        return camera

    def execute(self, command: str, max_steps: int = 500) -> bool:
        """
        Execute manipulation task using SmolVLA.

        Args:
            command: Natural language task description
            max_steps: Maximum inference steps
        """
        print(f"SmolVLA executing: {command}")

        with torch.no_grad():
            for step in range(max_steps):
                # Capture current observation
                observation = self._get_observation()

                # Add task instruction
                observation["task"] = command

                # Run inference
                action = self.policy.select_action(observation)

                # Send action to robot
                self._send_action(action)

                # Check if task complete (implementation-specific)
                if self._is_task_complete(observation, step):
                    break

        print(f"SmolVLA completed: {command}")
        return True

    def _get_observation(self) -> dict:
        """Get current robot observation."""
        # Capture image
        image = self.camera.read()

        # Get robot state
        robot_state = mortis_arm.robot.get_state()

        return {
            "observation.image": torch.from_numpy(image).to(self.device),
            "observation.state": torch.tensor(robot_state).to(self.device)
        }

    def _send_action(self, action: torch.Tensor):
        """Send predicted action to robot."""
        action_dict = self._action_to_dict(action)
        mortis_arm.robot.send_action(action_dict)

    def _action_to_dict(self, action: torch.Tensor) -> dict:
        """Convert action tensor to SO101 command format."""
        # Map action dimensions to joint names
        joint_names = [
            "shoulder_pan.pos",
            "shoulder_lift.pos",
            "elbow_flex.pos",
            "wrist_flex.pos",
            "wrist_roll.pos",
            "gripper.pos"
        ]

        return {
            name: float(action[i].cpu().numpy())
            for i, name in enumerate(joint_names)
        }

    def _is_task_complete(self, observation: dict, step: int) -> bool:
        """Determine if task is complete (heuristic or learned)."""
        # Simple heuristic: fixed number of steps
        # In practice, could use a learned termination classifier
        return step >= 400

# Global SmolVLA executor
smolvla_executor = None

def init_smolvla(checkpoint_path: str):
    global smolvla_executor
    smolvla_executor = SmolVLAExecutor(checkpoint_path)
```

### 7. Dataset Collection Infrastructure

#### Component: `DataCollector`

**Purpose:** Record human demonstrations for training the SmolVLA model.

**LeRobot Dataset Format:**

LeRobot uses a standardized dataset format with:
- Episodes: individual task demonstrations
- Observations: camera images, robot states
- Actions: robot joint commands
- Metadata: task descriptions, timestamps

**Data Collection Script:**

```python
import time
from pathlib import Path

import numpy as np
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset

class DataCollector:
    def __init__(self, dataset_name: str, repo_id: str):
        self.dataset_name = dataset_name
        self.repo_id = repo_id
        self.dataset_dir = Path(f"data/{dataset_name}")
        self.dataset_dir.mkdir(parents=True, exist_ok=True)

        self.dataset = LeRobotDataset.create(
            repo_id=repo_id,
            fps=30,
            robot_type="so101",
            keys=["observation.image", "observation.state", "action"]
        )

    def record_episode(self, task_description: str, duration: float = 30.0):
        """
        Record a single demonstration episode.

        Args:
            task_description: Natural language task description
            duration: Maximum recording duration in seconds
        """
        print(f"Recording episode: {task_description}")
        print("Press ENTER to start recording...")
        input()

        episode_data = {
            "observation.image": [],
            "observation.state": [],
            "action": [],
            "timestamp": [],
            "task": task_description
        }

        start_time = time.time()
        frame_count = 0

        print("Recording... Press CTRL+C to stop")

        try:
            while time.time() - start_time < duration:
                # Capture observation
                image = camera.read()
                state = mortis_arm.robot.get_state()

                # Record current state as the "action" (for behavior cloning)
                action = state.copy()

                # Store data
                episode_data["observation.image"].append(image)
                episode_data["observation.state"].append(state)
                episode_data["action"].append(action)
                episode_data["timestamp"].append(time.time() - start_time)

                frame_count += 1
                time.sleep(1/30)  # 30 FPS

        except KeyboardInterrupt:
            print(f"\nRecording stopped. Captured {frame_count} frames")

        # Save episode to dataset
        self._save_episode(episode_data)

        print(f"Episode saved: {task_description}")

    def _save_episode(self, episode_data: dict):
        """Save episode to LeRobot dataset."""
        episode_index = len(self.dataset)

        # Convert to numpy arrays
        images = np.array(episode_data["observation.image"])
        states = np.array(episode_data["observation.state"])
        actions = np.array(episode_data["action"])

        # Add to dataset
        self.dataset.add_episode({
            "observation.image": images,
            "observation.state": states,
            "action": actions,
            "episode_index": episode_index,
            "task": episode_data["task"]
        })

        # Save to disk
        self.dataset.save_to_disk(self.dataset_dir)

    def push_to_hub(self):
        """Upload dataset to the Hugging Face Hub."""
        self.dataset.push_to_hub(self.repo_id)
        print(f"Dataset pushed to: https://huggingface.co/datasets/{self.repo_id}")

# Usage script
def collect_demonstrations():
    collector = DataCollector(
        dataset_name="mortis_manipulation",
        repo_id="your-username/mortis-manipulation"
    )

    tasks = [
        "Pick up the skull and place it in the green cup",
        "Pick up the skull and place it in the orange cup",
        "Pick up the skull and place it in the purple cup",
        "Pick up the eyeball and place it in the green cup",
        "Pick up the eyeball and place it in the orange cup",
        "Pick up the eyeball and place it in the purple cup",
    ]

    for task in tasks:
        print(f"\n{'='*60}")
        print(f"Task: {task}")
        print(f"{'='*60}")

        # Record multiple demonstrations per task
        for demo_num in range(5):
            print(f"\nDemonstration {demo_num + 1}/5")
            collector.record_episode(task)

    # Upload to Hugging Face once all tasks are recorded
    collector.push_to_hub()
```

### 8. Training Pipeline

#### Component: `TrainingPipeline`

**Purpose:** Train the SmolVLA model on collected demonstration data.

**LeRobot Training Configuration:**

```yaml
# config/train_smolvla.yaml
defaults:
  - _self_
  - policy: smolvla

seed: 1000
dataset_repo_id: your-username/mortis-manipulation
video_backend: pyav

training:
  offline_steps: 100000
  online_steps: 0
  eval_freq: 10000
  save_freq: 10000
  log_freq: 100
  save_checkpoint: true

  batch_size: 8
  lr: 1e-4
  lr_scheduler: cosine
  lr_warmup_steps: 1000
  adam_betas: [0.9, 0.999]
  adam_weight_decay: 1e-6
  grad_clip_norm: 10.0

  delta_timestamps:
    action: "[i / ${fps} for i in range(${policy.chunk_size})]"

eval:
  n_episodes: 10
  batch_size: 10

policy:
  name: smolvla

  # Input dimensions
  input_shapes:
    observation.image: [3, 224, 224]
    observation.state: [6]  # 6 joints

  output_shapes:
    action: [6]  # 6 joint commands

  # Model architecture
  vision_backbone: "google/siglip-so400m-patch14-384"
  pretrained_backbone_weights: "google/siglip-so400m-patch14-384"

  # Action prediction
  chunk_size: 50  # Predict 50 steps ahead
  n_action_steps: 50

  # Training
  use_language_conditioning: true
  dropout: 0.1

device: cuda
use_amp: true  # Automatic mixed precision
```

**Training Script:**

```python
import hydra
from omegaconf import DictConfig

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.scripts.train import train

@hydra.main(config_path="config", config_name="train_smolvla", version_base="1.2")
def train_smolvla(cfg: DictConfig):
    """
    Train the SmolVLA model using the LeRobot training pipeline.

    Usage:
        python -m mortis.train
    """
    # Load dataset
    dataset = LeRobotDataset(
        repo_id=cfg.dataset_repo_id,
        split="train"
    )

    print(f"Dataset loaded: {len(dataset)} episodes")
    print(f"Training for {cfg.training.offline_steps} steps")

    # Run training
    train(cfg)

    print("Training complete!")
    print(f"Checkpoints saved to: outputs/train/{cfg.run_name}")

if __name__ == "__main__":
    train_smolvla()
```

**Simplified Training Command:**

```bash
# Using the lerobot CLI
python -m lerobot.scripts.train \
  policy=smolvla \
  env=so101 \
  dataset_repo_id=your-username/mortis-manipulation \
  training.offline_steps=100000 \
  training.batch_size=8 \
  training.save_freq=10000 \
  device=cuda \
  wandb.enable=true \
  wandb.project=mortis-smolvla
```

**Training Monitoring:**

```python
# Integration with Weights & Biases for tracking
import wandb

wandb.init(
    project="mortis-smolvla",
    config={
        "dataset": "mortis-manipulation",
        "policy": "smolvla",
        "batch_size": 8,
        "learning_rate": 1e-4
    }
)

# Logged automatically by LeRobot:
# - Training loss
# - Validation loss
# - Action prediction accuracy
# - Episode success rate
# - Sample predictions (videos)
```

## Data Models

### 1. Gemini Response Model

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional

class ResponseType(Enum):
    CONVERSATION = "conversation"
    MANIPULATION = "manipulation"

class Mood(Enum):
    OMINOUS = "ominous"
    PLAYFUL = "playful"
    ANGRY = "angry"
    NERVOUS = "nervous"
    TRIUMPHANT = "triumphant"
    MISCHIEVOUS = "mischievous"
    SINISTER = "sinister"
    CURIOUS = "curious"
    NEUTRAL = "neutral"

class Gesture(Enum):
    IDLE = "idle"
    WAVE = "wave"
    POINT_LEFT = "point_left"
    POINT_RIGHT = "point_right"
    GRAB = "grab"
    DROP = "drop"

@dataclass
class GeminiResponse:
    """Structured response from the Gemini API."""
    type: ResponseType
    message: str
    mood: Mood
    gesture: Optional[Gesture] = None
    command: Optional[str] = None

    @classmethod
    def from_json(cls, data: dict) -> 'GeminiResponse':
        """Parse JSON response from Gemini."""
        response_type = ResponseType(data["type"])

        if response_type == ResponseType.MANIPULATION:
            return cls(
                type=response_type,
                message=data["message"],
                mood=Mood(data["mood"]),
                command=data["command"]
            )
        else:
            return cls(
                type=response_type,
                message=data["message"],
                mood=Mood(data["mood"]),
                gesture=Gesture(data.get("gesture", "idle"))
            )
```
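A quick round-trip with this model, using a sample payload in the RESPONSE FORMAT defined by the system prompt (the message text here is illustrative):

```python
payload = {
    "type": "manipulation",
    "message": "Your wish feeds my power...",
    "mood": "sinister",
    "command": "Pick up the skull and place it in the green cup",
}

resp = GeminiResponse.from_json(payload)
assert resp.type is ResponseType.MANIPULATION
assert resp.command is not None and resp.gesture is None  # no gesture on manipulation
```
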
### 2. Task Execution Model

```python
import time
from dataclasses import dataclass
from enum import Enum
from typing import Optional

class TaskStatus(Enum):
    QUEUED = "queued"
    RUNNING = "running"
    COMPLETE = "complete"
    FAILED = "failed"

class TaskType(Enum):
    GESTURE = "gesture"
    MANIPULATION = "manipulation"

@dataclass
class Task:
    """Represents a robot task for execution."""
    id: str
    type: TaskType
    status: TaskStatus
    created_at: float
    started_at: Optional[float] = None
    completed_at: Optional[float] = None
    error: Optional[str] = None

    # Task-specific data
    gesture: Optional[str] = None
    command: Optional[str] = None

    @classmethod
    def create_gesture_task(cls, gesture: str) -> 'Task':
        """Create a gesture execution task."""
        return cls(
            id=f"gesture_{time.time()}",
            type=TaskType.GESTURE,
            status=TaskStatus.QUEUED,
            created_at=time.time(),
            gesture=gesture
        )

    @classmethod
    def create_manipulation_task(cls, command: str) -> 'Task':
        """Create a manipulation execution task."""
        return cls(
            id=f"manipulation_{time.time()}",
            type=TaskType.MANIPULATION,
            status=TaskStatus.QUEUED,
            created_at=time.time(),
            command=command
        )

    def start(self):
        """Mark task as started."""
        self.status = TaskStatus.RUNNING
        self.started_at = time.time()

    def complete(self):
        """Mark task as completed."""
        self.status = TaskStatus.COMPLETE
        self.completed_at = time.time()

    def fail(self, error: str):
        """Mark task as failed."""
        self.status = TaskStatus.FAILED
        self.completed_at = time.time()
        self.error = error

    @property
    def duration(self) -> Optional[float]:
        """Get task execution duration."""
        if self.started_at and self.completed_at:
            return self.completed_at - self.started_at
        return None
```
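The `AsyncExecutor` in section 5 currently passes raw dicts between its queues; a small wrapper like the one below (a sketch, not yet part of the executor) would let the worker record lifecycle timestamps and errors on `Task` objects instead:

```python
def run_task(task: Task, executor_fn) -> Task:
    """Run a queued Task, recording lifecycle timestamps and any error."""
    task.start()
    try:
        executor_fn()  # e.g. lambda: smolvla_executor.execute(task.command)
        task.complete()
    except Exception as exc:
        task.fail(str(exc))
    return task
```
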
### 3. Dataset Episode Model

```python
from dataclasses import dataclass

import numpy as np

@dataclass
class Episode:
    """Represents a single demonstration episode."""
    episode_index: int
    task_description: str
    images: np.ndarray      # Shape: (T, H, W, 3)
    states: np.ndarray      # Shape: (T, 6)
    actions: np.ndarray     # Shape: (T, 6)
    timestamps: np.ndarray  # Shape: (T,)

    @property
    def length(self) -> int:
        """Number of timesteps in the episode."""
        return len(self.timestamps)

    @property
    def duration(self) -> float:
        """Episode duration in seconds."""
        return self.timestamps[-1] - self.timestamps[0]

    def validate(self) -> bool:
        """Validate episode data consistency."""
        lengths = [
            len(self.images),
            len(self.states),
            len(self.actions),
            len(self.timestamps)
        ]
        return len(set(lengths)) == 1  # All same length
```
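Intended usage is to validate an episode before it is written to the dataset; a minimal example with dummy arrays:

```python
import numpy as np

episode = Episode(
    episode_index=0,
    task_description="Pick up the skull and place it in the green cup",
    images=np.zeros((3, 480, 640, 3), dtype=np.uint8),  # 3 dummy frames
    states=np.zeros((3, 6)),
    actions=np.zeros((3, 6)),
    timestamps=np.array([0.0, 1 / 30, 2 / 30]),  # 30 FPS spacing
)

assert episode.validate(), "inconsistent stream lengths"
print(episode.length, f"{episode.duration:.3f}s")  # 3 frames, ~0.067s
```
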
## Error Handling

### Error Categories and Recovery Strategies

#### 1. Gemini API Errors

**Error Types:**
- Authentication failures (invalid API key)
- Rate limiting (quota exceeded)
- Network timeouts
- Invalid responses (malformed JSON)

**Recovery Strategy:**

```python
import time
from typing import Optional

import google.generativeai as genai
from google.api_core import exceptions as gcloud_exceptions

class GeminiAPIError(Exception):
    """Base exception for Gemini API errors."""
    pass

class GeminiClient:
    def __init__(self, api_key: str, max_retries: int = 3):
        self.api_key = api_key
        self.max_retries = max_retries

    def send_message_with_retry(
        self,
        message: str,
        retry_count: int = 0
    ) -> Optional[GeminiResponse]:
        """Send message with exponential backoff retry."""
        try:
            response = self._send_message(message)
            return response

        except genai.types.BlockedPromptException as e:
            # Content safety filter triggered
            print(f"Prompt blocked by safety filter: {e}")
            return self._get_fallback_response()

        except gcloud_exceptions.ResourceExhausted:
            # google-generativeai raises ResourceExhausted on quota/rate limits
            if retry_count < self.max_retries:
                wait_time = 2 ** retry_count  # Exponential backoff
                print(f"Rate limited. Retrying in {wait_time}s...")
                time.sleep(wait_time)
                return self.send_message_with_retry(message, retry_count + 1)
            else:
                raise GeminiAPIError("Max retries exceeded for rate limit")

        except Exception as e:
            print(f"Gemini API error: {e}")
            return self._get_fallback_response()

    def _get_fallback_response(self) -> GeminiResponse:
        """Return a safe fallback response on API failure."""
        return GeminiResponse(
            type=ResponseType.CONVERSATION,
            message="The spirits are restless... try again.",
            mood=Mood.OMINOUS,
            gesture=Gesture.IDLE
        )
```

#### 2. STT/TTS Errors

**Error Types:**
- Audio format incompatibility
- Service unavailable
- Transcription failures (unclear audio)

**Recovery Strategy:**

```python
from typing import Optional

class AudioProcessingError(Exception):
    """Base exception for audio processing errors."""
    pass

def process_audio_with_fallback(audio_path: str) -> str:
    """Process audio, falling back from Gemini native audio to Google STT."""
    try:
        # Try Gemini native audio
        transcript = transcribe_with_gemini(audio_path)
        return transcript

    except Exception as e:
        print(f"Gemini audio processing failed: {e}")

        try:
            # Fall back to Google STT
            transcript = transcribe_with_google_stt(audio_path)
            return transcript

        except Exception as e:
            print(f"Google STT failed: {e}")
            raise AudioProcessingError(
                "Could not process audio. Please use text input."
            )

def synthesize_speech_with_fallback(text: str) -> Optional[str]:
    """Synthesize speech, falling back to text-only output."""
    try:
        audio_path = synthesize_with_google_tts(text)
        return audio_path

    except Exception as e:
        print(f"TTS failed: {e}. Returning text only.")
        return None  # UI will display text without audio
```

#### 3. SmolVLA Inference Errors

**Error Types:**
- Model loading failures
- GPU out of memory
- Invalid observations
- Action execution failures

**Recovery Strategy:**

```python
import torch

class SmolVLAError(Exception):
    """Base exception for SmolVLA errors."""
    pass

class SmolVLAExecutor:
    def execute_with_safety(self, command: str) -> bool:
        """Execute command with safety checks and recovery."""
        try:
            # Pre-execution validation
            if not self._validate_command(command):
                raise SmolVLAError(f"Invalid command: {command}")

            if not self._check_workspace_clear():
                raise SmolVLAError("Workspace not clear. Remove obstacles.")

            # Execute with timeout
            success = self._execute_with_timeout(command, timeout=60.0)

            if not success:
                raise SmolVLAError("Execution timeout")

            return True

        except torch.cuda.OutOfMemoryError:
            print("GPU OOM. Clearing cache and retrying...")
            torch.cuda.empty_cache()
            return self._execute_with_timeout(command, timeout=60.0)

        except Exception as e:
            print(f"SmolVLA execution failed: {e}")
            # Return to safe position
            self._emergency_stop()
            return False

    def _emergency_stop(self):
        """Return robot to a safe idle position."""
        print("Emergency stop: returning to idle position")
        mortis_arm.move_arm("idle")

    def _validate_command(self, command: str) -> bool:
        """Validate command is in the trained set."""
        return command in self.valid_commands

    def _check_workspace_clear(self) -> bool:
        """Check whether the workspace is safe for execution."""
        # Could use computer vision to detect obstacles
        # For now, assume clear
        return True
```

#### 4. Robot Hardware Errors

**Error Types:**
- Connection failures
- Servo errors
- Position limits exceeded
- Communication timeouts

**Recovery Strategy:**

```python
import time

class RobotError(Exception):
    """Base exception for robot hardware errors."""
    pass

class MortisArm:
    def move_arm_safe(self, gesture_name: str) -> bool:
        """Execute gesture with error handling."""
        if not self.connected:
            try:
                self.connect()
            except Exception as e:
                print(f"Failed to connect to robot: {e}")
                return False

        try:
            self.move_arm(gesture_name)
            return True

        except Exception as e:
            print(f"Gesture execution failed: {e}")

            # Attempt recovery
            try:
                print("Attempting to reconnect...")
                self.disconnect()
                time.sleep(1)
                self.connect()
                self.move_arm("idle")
                return False

            except Exception as e:
                print(f"Recovery failed: {e}")
                self.connected = False
                return False
```

### Error Reporting to User

```python
def format_error_message(error: Exception) -> str:
    """Format an error for user display."""
    error_messages = {
        GeminiAPIError: "🔮 The spirits are not responding. Please try again.",
        AudioProcessingError: "🎤 Could not understand audio. Please try text input.",
        SmolVLAError: "🤖 Mortis cannot perform that action right now.",
        RobotError: "⚠️ Robot connection lost. Attempting to reconnect...",
    }

    error_type = type(error)
    return error_messages.get(error_type, "❌ An unexpected error occurred.")
```
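One way the UI handler might surface these messages (a sketch; `mortis_reply` is the three-output handler defined in section 5):

```python
def safe_mortis_reply(message, history, model_name):
    """Wrap the chat handler so any failure maps to an in-character message."""
    try:
        return mortis_reply(message, history, model_name)
    except Exception as exc:
        friendly = format_error_message(exc)
        return friendly, None, friendly  # chat text, no audio, status line
```
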
## Testing Strategy

### 1. Unit Testing

**Components to Test:**
- Gemini API client (with mocked responses)
- Intent router (parsing and validation)
- Data models (serialization/deserialization)
- Audio processing utilities

**Example Test:**

```python
import pytest
from unittest.mock import Mock, patch

from mortis.gemini_client import GeminiClient, GeminiResponse, ResponseType

def test_gemini_response_parsing():
    """Test parsing of Gemini JSON responses."""
    # Test conversation response
    conv_data = {
        "type": "conversation",
        "message": "Beware, mortal...",
        "mood": "ominous",
        "gesture": "wave"
    }
    response = GeminiResponse.from_json(conv_data)
    assert response.type == ResponseType.CONVERSATION
    assert response.gesture.value == "wave"

    # Test manipulation response
    manip_data = {
        "type": "manipulation",
        "message": "As you wish...",
        "mood": "sinister",
        "command": "Pick up the skull and place it in the green cup"
    }
    response = GeminiResponse.from_json(manip_data)
    assert response.type == ResponseType.MANIPULATION
    assert response.command is not None

@patch('google.generativeai.GenerativeModel')
def test_gemini_client_retry(mock_model):
    """Test retry logic for API failures."""
    from google.api_core import exceptions as gcloud_exceptions

    client = GeminiClient(api_key="test_key", max_retries=3)

    # Simulate a rate limit error, then success
    mock_model.return_value.generate_content.side_effect = [
        gcloud_exceptions.ResourceExhausted("Rate limited"),
        Mock(text='{"type": "conversation", "message": "Hello", "mood": "neutral", "gesture": "idle"}')
    ]

    response = client.send_message_with_retry("Hello")
    assert response is not None
    assert mock_model.return_value.generate_content.call_count == 2
```

### 2. Integration Testing

**Test Scenarios:**
- End-to-end voice input → Gemini → gesture execution
- Text input → intent detection → SmolVLA execution
- Dataset collection → training → inference pipeline
- Error recovery flows

**Example Test:**

```python
@pytest.mark.integration
def test_voice_to_gesture_flow():
    """Test complete voice input to gesture execution."""
    # Recorded test audio fixture
    test_audio = "tests/fixtures/test_wave.wav"

    # Process audio
    transcript = process_audio(test_audio)
    assert "wave" in transcript.lower()

    # Send to Gemini
    response = gemini_client.send_message(transcript)
    assert response.type == ResponseType.CONVERSATION
    assert response.gesture == Gesture.WAVE

    # Execute gesture (with mock robot)
    with patch.object(mortis_arm, 'move_arm') as mock_move:
        execute_gesture(response.gesture)
        mock_move.assert_called_once_with("wave")

@pytest.mark.integration
@pytest.mark.slow
def test_smolvla_inference():
    """Test SmolVLA model inference (requires GPU)."""
    if not torch.cuda.is_available():
        pytest.skip("GPU not available")

    # Load test checkpoint
    executor = SmolVLAExecutor("tests/fixtures/test_checkpoint")

    # Execute test command
    command = "Pick up the skull and place it in the green cup"
    success = executor.execute(command, max_steps=10)

    assert success
```

### 3. System Testing

**Test Scenarios:**
- Multi-user concurrent access
- Long-running operation stability
- Resource usage (GPU memory, CPU)
- Network failure recovery

**Performance Benchmarks:**

```python
import time

@pytest.mark.benchmark
def test_gemini_response_time():
    """Benchmark Gemini API response time."""
    times = []
    for _ in range(10):
        start = time.time()
        response = gemini_client.send_message("Hello Mortis")
        elapsed = time.time() - start
        times.append(elapsed)

    avg_time = sum(times) / len(times)
    assert avg_time < 2.0, f"Average response time {avg_time}s exceeds 2s threshold"

@pytest.mark.benchmark
def test_smolvla_inference_time():
    """Benchmark SmolVLA inference speed."""
    executor = SmolVLAExecutor("checkpoints/best_model")

    start = time.time()
    executor.execute("Pick up the skull and place it in the green cup", max_steps=100)
    elapsed = time.time() - start

    assert elapsed < 30.0, f"Inference time {elapsed}s exceeds 30s threshold"
```

### 4. User Acceptance Testing

**Test Scenarios:**
- Voice recognition accuracy with different accents
- Task success rate for manipulation commands
- UI responsiveness during long operations
- Error message clarity and helpfulness

**Manual Test Checklist:**

```markdown
## Voice Input Testing
- [ ] Clear speech recognized correctly
- [ ] Background noise handled gracefully
- [ ] Multiple languages supported (if applicable)
- [ ] Audio feedback provided to user

## Manipulation Task Testing
- [ ] All 6 trained tasks execute successfully
- [ ] Task variations handled appropriately
- [ ] Robot returns to safe position after completion
- [ ] Visual feedback clear during execution

## Error Handling Testing
- [ ] API failures display helpful messages
- [ ] Robot errors trigger safe shutdown
- [ ] Network issues handled gracefully
- [ ] Recovery procedures work as expected

## UI/UX Testing
- [ ] Interface remains responsive during tasks
- [ ] Status updates clear and timely
- [ ] Audio playback works correctly
- [ ] Webcam feed displays properly
```

### 5. Safety Testing

**Critical Safety Tests:**

```python
import time
from threading import Thread

def test_emergency_stop():
    """Test emergency stop functionality."""
    executor = SmolVLAExecutor("checkpoints/best_model")

    # Start execution
    task_thread = Thread(target=executor.execute, args=("test command",))
    task_thread.start()

    # Trigger emergency stop
    time.sleep(1)
    executor._emergency_stop()

    # Verify robot in safe position
    state = mortis_arm.robot.get_state()
    assert state == HOME_POSE

def test_workspace_collision_detection():
    """Test collision detection and avoidance."""
    # Place an obstacle in the workspace
    # Attempt a manipulation task
    # Verify the task aborts safely
    pass
```

## Deployment and Configuration

### Environment Configuration

**Required Environment Variables:**

```bash
# .env file
# Gemini API
GEMINI_API_KEY=your_google_api_key
GEMINI_MODEL=gemini-2.0-flash-exp
GEMINI_TEMPERATURE=0.2

# Google Cloud (for STT/TTS if not using Gemini native audio)
GOOGLE_APPLICATION_CREDENTIALS=/path/to/service-account-key.json

# Robot Configuration
ROBOT_PORT=/dev/ttyACM1
ROBOT_CALIBRATION_DIR=.cache/calibration/so101/

# SmolVLA Model
SMOLVLA_CHECKPOINT_PATH=checkpoints/smolvla_best.pt
SMOLVLA_DEVICE=cuda

# Application
PORT=7860
DEBUG=false

# Optional: Weights & Biases for training
WANDB_API_KEY=your_wandb_key
WANDB_PROJECT=mortis-smolvla
```
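A sketch of loading this configuration at startup with `python-dotenv` (already listed in the project dependencies); the variable names match the `.env` file above:

```python
import os

from dotenv import load_dotenv

load_dotenv()  # reads .env from the working directory

GEMINI_API_KEY = os.environ["GEMINI_API_KEY"]  # required: fail fast if missing
GEMINI_MODEL = os.getenv("GEMINI_MODEL", "gemini-2.0-flash-exp")
GEMINI_TEMPERATURE = float(os.getenv("GEMINI_TEMPERATURE", "0.2"))
ROBOT_PORT = os.getenv("ROBOT_PORT", "/dev/ttyACM1")
SMOLVLA_CHECKPOINT_PATH = os.getenv("SMOLVLA_CHECKPOINT_PATH", "checkpoints/smolvla_best.pt")
SMOLVLA_DEVICE = os.getenv("SMOLVLA_DEVICE", "cuda")
```
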
1643
+
1644
+ ### Dependency Management
1645
+
1646
+ **Updated pyproject.toml:**
1647
+
1648
+ ```toml
1649
+ [project]
1650
+ name = "mortis"
1651
+ version = "0.2.0"
1652
+ description = "Mortis: Multi-modal AI Halloween Experience with SmolVLA"
1653
+ requires-python = ">=3.12"
1654
+ dependencies = [
1655
+ "gradio>=5.49.1",
1656
+ "lerobot[async,feetech,intelrealsense,smolvla]>=0.4.0",
1657
+ "python-dotenv>=1.2.1",
1658
+
1659
+ # Gemini and Google Cloud
1660
+ "google-generativeai>=0.8.0",
1661
+ "google-cloud-speech>=2.26.0",
1662
+ "google-cloud-texttospeech>=2.16.0",
1663
+
1664
+ # ML and Vision
1665
+ "torch>=2.0.0",
1666
+ "torchvision>=0.15.0",
1667
+ "transformers>=4.40.0",
1668
+ "pillow>=10.0.0",
1669
+
1670
+ # Data and utilities
1671
+ "numpy>=1.24.0",
1672
+ "opencv-python>=4.8.0",
1673
+ "datasets>=2.14.0",
1674
+ ]
1675
+
1676
+ [project.optional-dependencies]
1677
+ dev = [
1678
+ "pytest>=7.4.0",
1679
+ "pytest-asyncio>=0.21.0",
1680
+ "pytest-benchmark>=4.0.0",
1681
+ "black>=23.0.0",
1682
+ "ruff>=0.1.0",
1683
+ ]
1684
+
1685
+ training = [
1686
+ "wandb>=0.16.0",
1687
+ "hydra-core>=1.3.0",
1688
+ "tensorboard>=2.14.0",
1689
+ ]
1690
+
1691
+ [project.scripts]
1692
+ mortis = "mortis.app:main"
1693
+ calibrate = "mortis.calibrate:main"
1694
+ collect-data = "mortis.collect_data:main"
1695
+ train-smolvla = "mortis.train:main"
1696
+ ```
1697
+
1698
+ ### Installation Steps
1699
+
1700
+ ```bash
1701
+ # 1. Clone repository
1702
+ git clone https://github.com/your-username/mortis.git
1703
+ cd mortis
1704
+
1705
+ # 2. Install dependencies
1706
+ make install
1707
+
1708
+ # 3. Configure environment
1709
+ cp .env.example .env
1710
+ # Edit .env with your API keys
1711
+
1712
+ # 4. Calibrate robot (first time only)
1713
+ make calibrate
1714
+
1715
+ # 5. Download or train SmolVLA model
1716
+ # Option A: Download pre-trained model
1717
+ python -m mortis.download_model --checkpoint smolvla_mortis_v1
1718
+
1719
+ # Option B: Train from scratch
1720
+ make collect-data
1721
+ make train-smolvla
1722
+
1723
+ # 6. Run application
1724
+ make run
1725
+ ```
1726
+
1727
+ ### Docker Deployment (Optional)
1728
+
1729
+ **Dockerfile:**
1730
+
1731
+ ```dockerfile
1732
+ FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
1733
+
1734
+ # Install Python and system dependencies
1735
+ # Ubuntu 22.04 ships Python 3.10, so 3.12 comes from the deadsnakes PPA
+ RUN apt-get update && apt-get install -y software-properties-common \
+     && add-apt-repository -y ppa:deadsnakes/ppa \
+     && apt-get update && apt-get install -y \
+     python3.12 \
+     python3-pip \
+     libusb-1.0-0 \
+     udev \
+     && rm -rf /var/lib/apt/lists/*
1741
+
1742
+ # Install uv package manager
1743
+ RUN pip install uv
1744
+
1745
+ WORKDIR /app
1746
+
1747
+ # Copy project files
1748
+ COPY pyproject.toml uv.lock ./
1749
+ COPY src/ ./src/
1750
+ COPY assets/ ./assets/
1751
+
1752
+ # Install dependencies
1753
+ RUN uv sync --frozen
1754
+
1755
+ # Expose Gradio port
1756
+ EXPOSE 7860
1757
+
1758
+ # Run application
1759
+ CMD ["uv", "run", "mortis"]
1760
+ ```
1761
+
1762
+ **docker-compose.yml:**
1763
+
1764
+ ```yaml
1765
+ version: '3.8'
1766
+
1767
+ services:
1768
+ mortis:
1769
+ build: .
1770
+ ports:
1771
+ - "7860:7860"
1772
+ devices:
1773
+ - /dev/ttyACM1:/dev/ttyACM1 # Robot USB connection
1774
+ volumes:
1775
+ - ./.env:/app/.env
1776
+ - ./checkpoints:/app/checkpoints
1777
+ - ./.cache:/app/.cache
1778
+ environment:
1779
+ - NVIDIA_VISIBLE_DEVICES=all
1780
+ runtime: nvidia
1781
+ restart: unless-stopped
1782
+ ```
1783
+
1784
+ ### System Requirements
1785
+
1786
+ **Minimum Requirements:**
1787
+ - CPU: 4 cores
1788
+ - RAM: 16 GB
1789
+ - GPU: NVIDIA GPU with 8GB VRAM (for SmolVLA inference)
1790
+ - Storage: 50 GB (for models and datasets)
1791
+ - OS: Ubuntu 22.04 or later
1792
+ - USB: Available port for SO101 robot
1793
+
1794
+ **Recommended Requirements:**
1795
+ - CPU: 8+ cores
1796
+ - RAM: 32 GB
1797
+ - GPU: NVIDIA RTX 3090 or better (24GB VRAM)
1798
+ - Storage: 100 GB SSD
1799
+ - Network: Stable internet for Gemini API
1800
+
1801
+ ### Monitoring and Logging
1802
+
1803
+ **Logging Configuration:**
1804
+
1805
+ ```python
+ import logging
+ import os
+ import time
+ from pathlib import Path
+
+ # Configure logging
+ LOG_DIR = Path("logs")
+ LOG_DIR.mkdir(exist_ok=True)
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+     handlers=[
+         logging.FileHandler(LOG_DIR / f"mortis_{int(time.time())}.log"),
+         logging.StreamHandler()
+     ]
+ )
+
+ logger = logging.getLogger("mortis")
+
+ # Log important events (config values read from the environment)
+ logger.info("Application started")
+ logger.info(f"Gemini model: {os.getenv('GEMINI_MODEL')}")
+ logger.info(f"SmolVLA checkpoint: {os.getenv('SMOLVLA_CHECKPOINT_PATH')}")
+ ```
1829
+
1830
+ **Metrics to Monitor:**
1831
+ - Gemini API response times
1832
+ - SmolVLA inference times
1833
+ - Task success rates
1834
+ - Error frequencies
1835
+ - GPU memory usage
1836
+ - Robot connection status
1837
+
1838
+
1839
+ ## Migration Strategy
1840
+
1841
+ ### Phase 1: Gemini API Integration (Week 1)
1842
+
1843
+ **Goals:**
1844
+ - Replace existing LLM API with Gemini
1845
+ - Maintain current gesture functionality
1846
+ - Add structured JSON response parsing
1847
+
1848
+ **Tasks:**
1849
+ 1. Create `GeminiClient` class
1850
+ 2. Update system prompt for Gemini
1851
+ 3. Modify `ask_mortis()` to use Gemini API
1852
+ 4. Test with existing gestures
1853
+ 5. Update environment configuration
1854
+
1855
+ **Validation:**
1856
+ - All existing gestures work with Gemini
1857
+ - Response times comparable to previous API
1858
+ - Character personality maintained
1859
+
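+ A minimal sketch of the Phase 1 client, assuming the `google-generativeai` SDK from the dependency list; `MORTIS_SYSTEM_PROMPT` stands in for the prompt designed in this phase:
+
+ ```python
+ import json
+ import os
+
+ import google.generativeai as genai
+
+ MORTIS_SYSTEM_PROMPT = "..."  # placeholder: character definition + task list
+
+ genai.configure(api_key=os.environ["GEMINI_API_KEY"])
+
+ model = genai.GenerativeModel(
+     model_name=os.getenv("GEMINI_MODEL", "gemini-2.0-flash-exp"),
+     system_instruction=MORTIS_SYSTEM_PROMPT,
+     generation_config={
+         "temperature": float(os.getenv("GEMINI_TEMPERATURE", "0.2")),
+         "response_mime_type": "application/json",  # request structured JSON
+     },
+ )
+
+ def ask_gemini(user_message: str) -> dict:
+     """Send one message and parse the structured JSON reply."""
+     response = model.generate_content(user_message)
+     return json.loads(response.text)
+ ```
+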
1860
+ ### Phase 2: Voice Input/Output (Week 2)
1861
+
1862
+ **Goals:**
1863
+ - Add audio input component to Gradio
1864
+ - Implement STT using Gemini native audio or Google STT
1865
+ - Add TTS for voice responses
1866
+ - Test multi-modal interaction
1867
+
1868
+ **Tasks:**
1869
+ 1. Add audio input/output components to UI
1870
+ 2. Implement STT service
1871
+ 3. Implement TTS service
1872
+ 4. Update UI to handle audio flows
1873
+ 5. Test voice interaction end-to-end
1874
+
1875
+ **Validation:**
1876
+ - Voice input transcribed accurately
1877
+ - Audio responses play correctly
1878
+ - Text input still works
1879
+ - UI remains responsive
1880
+
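+ A minimal voice-output sketch using the Google Cloud TTS client from the dependency list; the voice selection and pitch values are illustrative, not tuned:
+
+ ```python
+ from google.cloud import texttospeech
+
+ def synthesize_reply(text: str, out_path: str = "outputs/reply.mp3") -> str:
+     """Render a Gemini text reply as browser-playable MP3."""
+     client = texttospeech.TextToSpeechClient()
+     response = client.synthesize_speech(
+         input=texttospeech.SynthesisInput(text=text),
+         voice=texttospeech.VoiceSelectionParams(
+             language_code="en-US",
+             ssml_gender=texttospeech.SsmlVoiceGender.MALE,
+         ),
+         audio_config=texttospeech.AudioConfig(
+             audio_encoding=texttospeech.AudioEncoding.MP3,
+             pitch=-4.0,          # illustrative: lower pitch for a spectral voice
+             speaking_rate=0.95,  # illustrative: slightly slower delivery
+         ),
+     )
+     with open(out_path, "wb") as f:
+         f.write(response.audio_content)
+     return out_path
+ ```
+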
1881
+ ### Phase 3: Dataset Collection (Week 3)
1882
+
1883
+ **Goals:**
1884
+ - Set up data collection infrastructure
1885
+ - Record demonstrations for all 6 tasks
1886
+ - Validate and upload dataset to Hugging Face
1887
+
1888
+ **Tasks:**
1889
+ 1. Create `DataCollector` class
1890
+ 2. Set up camera and robot for recording
1891
+ 3. Record 5-10 demonstrations per task
1892
+ 4. Validate dataset quality
1893
+ 5. Push to Hugging Face Hub
1894
+
1895
+ **Validation:**
1896
+ - All 6 tasks have sufficient demonstrations
1897
+ - Data quality is high (clear images, smooth motions)
1898
+ - Dataset loads correctly in LeRobot
1899
+
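+ A quick validation sketch, assuming the LeRobot dataset API; the import path and attribute names vary across LeRobot releases, so verify against the installed version:
+
+ ```python
+ # Import path differs across LeRobot releases; adjust to the installed version.
+ from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+
+ ds = LeRobotDataset("your-username/mortis_manipulation")  # illustrative repo id
+ print(f"Episodes: {ds.num_episodes}, frames: {ds.num_frames}")
+ print(ds.meta.tasks)  # should cover all 6 Task Strings
+ ```
+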
1900
+ ### Phase 4: SmolVLA Training (Week 4)
1901
+
1902
+ **Goals:**
1903
+ - Train SmolVLA model on collected data
1904
+ - Evaluate model performance
1905
+ - Select best checkpoint
1906
+
1907
+ **Tasks:**
1908
+ 1. Configure training pipeline
1909
+ 2. Run training for 100k steps
1910
+ 3. Monitor training metrics
1911
+ 4. Evaluate on validation set
1912
+ 5. Select and save best checkpoint
1913
+
1914
+ **Validation:**
1915
+ - Training converges (loss decreases)
1916
+ - Validation performance acceptable
1917
+ - Model can execute at least 3/6 tasks successfully
1918
+
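+ One way to launch the run, wrapping the `lerobot-train` CLI named in the requirements; the flag names follow recent LeRobot releases and should be checked against `lerobot-train --help`:
+
+ ```python
+ import subprocess
+
+ subprocess.run(
+     [
+         "lerobot-train",
+         "--policy.path=lerobot/smolvla_base",  # warm-start from pre-trained SmolVLA
+         "--dataset.repo_id=your-username/mortis_manipulation",
+         "--batch_size=64",
+         "--steps=100000",
+         "--output_dir=checkpoints/smolvla_mortis",
+         "--policy.device=cuda",
+         "--wandb.enable=true",
+     ],
+     check=True,
+ )
+ ```
+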
1919
+ ### Phase 5: Intent Detection and Routing (Week 5)
1920
+
1921
+ **Goals:**
1922
+ - Implement intent detection in Gemini prompt
1923
+ - Create intent router
1924
+ - Add command validation
1925
+
1926
+ **Tasks:**
1927
+ 1. Update Gemini system prompt with task definitions
1928
+ 2. Create `IntentRouter` class
1929
+ 3. Implement command validation
1930
+ 4. Test intent detection accuracy
1931
+ 5. Handle edge cases
1932
+
1933
+ **Validation:**
1934
+ - Manipulation commands detected correctly (>90% accuracy)
1935
+ - Conversational inputs routed to gestures
1936
+ - Invalid commands handled gracefully
1937
+
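+ A minimal routing sketch; the JSON keys (`command`, `message`) are assumptions about the response format defined in the Gemini prompt:
+
+ ```python
+ from dataclasses import dataclass
+
+ VALID_TASKS = {
+     "Pick up the skull and place it in the green cup",
+     # ...the remaining five trained Task Strings
+ }
+
+ @dataclass
+ class Intent:
+     is_manipulation: bool
+     command: str | None
+     reply: str
+
+ def route(gemini_json: dict) -> Intent:
+     """Route a parsed Gemini response to manipulation or gesture handling."""
+     command = gemini_json.get("command")
+     if command in VALID_TASKS:
+         return Intent(True, command, gemini_json.get("message", ""))
+     # Unknown or missing command: fall back to the conversational gesture flow
+     return Intent(False, None, gemini_json.get("message", ""))
+ ```
+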
1938
+ ### Phase 6: Asynchronous Execution (Week 6)
1939
+
1940
+ **Goals:**
1941
+ - Implement async task execution
1942
+ - Add status tracking and UI updates
1943
+ - Test UI responsiveness
1944
+
1945
+ **Tasks:**
1946
+ 1. Create `AsyncExecutor` class
1947
+ 2. Implement task queue
1948
+ 3. Add status display to UI
1949
+ 4. Test with long-running tasks
1950
+ 5. Handle concurrent requests
1951
+
1952
+ **Validation:**
1953
+ - UI remains responsive during SmolVLA execution
1954
+ - Status updates appear correctly
1955
+ - Multiple tasks can be queued
1956
+ - Errors don't crash the system
1957
+
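+ A reduced threading sketch of the executor described above (the full class also defines `Task` and `StatusUpdate` dataclasses):
+
+ ```python
+ import queue
+ import threading
+
+ class AsyncExecutor:
+     """Run gesture tasks on a worker thread so the UI never blocks."""
+
+     def __init__(self) -> None:
+         self.tasks: queue.Queue = queue.Queue()
+         self.status: queue.Queue = queue.Queue()
+         self._worker = threading.Thread(target=self._run, daemon=True)
+         self._worker.start()
+
+     def submit(self, fn, *args) -> None:
+         self.tasks.put((fn, args))
+         self.status.put("queued")
+
+     def _run(self) -> None:
+         while True:
+             fn, args = self.tasks.get()
+             self.status.put("running")
+             try:
+                 fn(*args)
+                 self.status.put("complete")
+             except Exception as exc:  # report failures; keep the worker alive
+                 self.status.put(f"failed: {exc}")
+ ```
+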
1958
+ ### Phase 7: Integration and Testing (Week 7)
1959
+
1960
+ **Goals:**
1961
+ - Integrate all components
1962
+ - Comprehensive testing
1963
+ - Bug fixes and optimization
1964
+
1965
+ **Tasks:**
1966
+ 1. Integration testing
1967
+ 2. Performance optimization
1968
+ 3. Error handling improvements
1969
+ 4. Documentation updates
1970
+ 5. User acceptance testing
1971
+
1972
+ **Validation:**
1973
+ - All features work together
1974
+ - Performance meets requirements
1975
+ - Error handling robust
1976
+ - Documentation complete
1977
+
1978
+ ### Phase 8: Deployment and Monitoring (Week 8)
1979
+
1980
+ **Goals:**
1981
+ - Deploy to production environment
1982
+ - Set up monitoring
1983
+ - Create user documentation
1984
+
1985
+ **Tasks:**
1986
+ 1. Prepare deployment environment
1987
+ 2. Configure monitoring and logging
1988
+ 3. Create user guide
1989
+ 4. Deploy application
1990
+ 5. Monitor initial usage
1991
+
1992
+ **Validation:**
1993
+ - Application runs stably
1994
+ - Monitoring captures key metrics
1995
+ - Users can operate system successfully
1996
+
1997
+ ### Rollback Plan
1998
+
1999
+ If critical issues arise during migration:
2000
+
2001
+ 1. **Immediate Rollback:**
2002
+ - Revert to previous LLM API
2003
+ - Disable voice features
2004
+ - Use gesture-only mode
2005
+
2006
+ 2. **Partial Rollback:**
2007
+ - Keep Gemini API
2008
+ - Disable SmolVLA (gestures only)
2009
+ - Disable voice features
2010
+
2011
+ 3. **Data Preservation:**
2012
+ - All datasets backed up to Hugging Face
2013
+ - Model checkpoints saved to cloud storage
2014
+ - Configuration files version controlled
2015
+
2016
+ ### Risk Mitigation
2017
+
2018
+ **Risk: Gemini API costs exceed budget**
2019
+ - Mitigation: Set API usage limits, implement caching, use smaller models
2020
+
2021
+ **Risk: SmolVLA training fails to converge**
2022
+ - Mitigation: Collect more data, adjust hyperparameters, use pre-trained weights
2023
+
2024
+ **Risk: Voice recognition accuracy too low**
2025
+ - Mitigation: Use better STT service, add noise filtering, provide text fallback
2026
+
2027
+ **Risk: GPU memory insufficient for SmolVLA**
2028
+ - Mitigation: Reduce batch size, use model quantization, upgrade hardware
2029
+
2030
+ **Risk: Robot safety issues during autonomous execution**
2031
+ - Mitigation: Implement workspace monitoring, add emergency stop, limit motion range
2032
+
2033
+
2034
+ ## Design Decisions and Rationale
2035
+
2036
+ ### 1. Why Gemini API over Other LLMs?
2037
+
2038
+ **Decision:** Use Google Gemini API as the primary LLM.
2039
+
2040
+ **Rationale:**
2041
+ - Native multi-modal support (audio, images, text)
2042
+ - Structured output via JSON mode
2043
+ - Strong intent detection capabilities
2044
+ - Integrated with Google Cloud ecosystem (STT/TTS)
2045
+ - Competitive pricing and performance
2046
+ - Good documentation and Python SDK
2047
+
2048
+ **Alternatives Considered:**
2049
+ - OpenAI GPT-4: More expensive, separate APIs for audio
2050
+ - Anthropic Claude: No native audio support
2051
+ - Local LLMs: Insufficient quality for intent detection
2052
+
2053
+ ### 2. Why asyncio.Queue over Redis?
2054
+
2055
+ **Decision:** Use Python's asyncio.Queue for task management.
2056
+
2057
+ **Rationale:**
2058
+ - Single-machine deployment (no distributed workers needed)
2059
+ - No external dependencies
2060
+ - Simpler implementation and debugging
2061
+ - Sufficient for expected load (single user at a time)
2062
+ - Lower latency than network-based queue
2063
+
2064
+ **When to Reconsider:**
2065
+ - Multiple robot arms
2066
+ - Distributed deployment
2067
+ - High concurrent user load
2068
+ - Need for task persistence across restarts
2069
+
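+ For illustration, the single-worker pattern this decision implies; `execute_on_robot` is a hypothetical stand-in for the blocking robot call:
+
+ ```python
+ import asyncio
+ import time
+
+ def execute_on_robot(command: str) -> None:
+     """Hypothetical blocking robot call (stub for illustration)."""
+     time.sleep(2)
+
+ task_queue: asyncio.Queue = asyncio.Queue()
+
+ async def worker() -> None:
+     while True:
+         command = await task_queue.get()
+         # Run blocking robot I/O in a thread so the event loop stays responsive
+         await asyncio.to_thread(execute_on_robot, command)
+         task_queue.task_done()
+ ```
+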
2070
+ ### 3. Why SmolVLA over Other Robot Learning Approaches?
2071
+
2072
+ **Decision:** Use SmolVLA for manipulation tasks.
2073
+
2074
+ **Rationale:**
2075
+ - Vision-language-action model (understands natural language)
2076
+ - Integrated with LeRobot framework
2077
+ - End-to-end learning (no manual feature engineering)
2078
+ - Proven performance on manipulation tasks
2079
+ - Active development and community support
2080
+
2081
+ **Alternatives Considered:**
2082
+ - Reinforcement Learning: Requires extensive training, safety concerns
2083
+ - Classical Motion Planning: Requires manual programming, less flexible
2084
+ - Behavior Cloning (non-VLA): No language understanding
2085
+
2086
+ ### 4. Why Hybrid Gesture + SmolVLA Approach?
2087
+
2088
+ **Decision:** Keep predefined gestures for conversational responses, add SmolVLA for manipulation.
2089
+
2090
+ **Rationale:**
2091
+ - Gestures are fast and reliable (no inference needed)
2092
+ - SmolVLA reserved for complex manipulation tasks
2093
+ - Reduces GPU usage for simple interactions
2094
+ - Maintains backward compatibility
2095
+ - Clear separation of concerns
2096
+
2097
+ **Benefits:**
2098
+ - Lower latency for conversational interactions
2099
+ - More robust (gestures involve no model inference that could fail)
2100
+ - Better resource utilization
2101
+
2102
+ ### 5. Why Gradio over Custom Web Framework?
2103
+
2104
+ **Decision:** Continue using Gradio for the web interface.
2105
+
2106
+ **Rationale:**
2107
+ - Already integrated in existing system
2108
+ - Excellent support for audio/video components
2109
+ - Built-in WebSocket handling for real-time updates
2110
+ - Rapid prototyping and iteration
2111
+ - Good documentation and examples
2112
+
2113
+ **Limitations Acknowledged:**
2114
+ - Less customization than React/Vue
2115
+ - Limited styling options
2116
+ - Not ideal for production-scale applications
2117
+
2118
+ **When to Reconsider:**
2119
+ - Need for complex custom UI
2120
+ - Mobile app requirements
2121
+ - High-scale deployment (>100 concurrent users)
2122
+
2123
+ ### 6. Why Google TTS over Local Alternatives?
2124
+
2125
+ **Decision:** Use Google Cloud Text-to-Speech for voice output.
2126
+
2127
+ **Rationale:**
2128
+ - High-quality neural voices
2129
+ - Consistent with Gemini ecosystem
2130
+ - Low latency
2131
+ - Voice customization options (pitch, speed)
2132
+ - Reliable service
2133
+
2134
+ **Alternatives Considered:**
2135
+ - pyttsx3: Lower quality, robotic voice
2136
+ - gTTS: Limited voice options, requires internet anyway
2137
+ - Local neural TTS: High GPU usage, slower
2138
+
2139
+ ### 7. Why Separate Training and Inference Scripts?
2140
+
2141
+ **Decision:** Keep training infrastructure separate from runtime application.
2142
+
2143
+ **Rationale:**
2144
+ - Training is offline, one-time process
2145
+ - Different hardware requirements (training needs more VRAM)
2146
+ - Cleaner code organization
2147
+ - Easier to update training without affecting production
2148
+ - Can train on different machine than deployment
2149
+
2150
+ **Implementation:**
2151
+ - Training scripts in `mortis/train.py`
2152
+ - Inference in `mortis/smolvla_executor.py`
2153
+ - Shared model configuration
2154
+
2155
+ ### 8. Why Not Use Gemini for Robot Control Directly?
2156
+
2157
+ **Decision:** Use Gemini for intent detection, SmolVLA for action generation.
2158
+
2159
+ **Rationale:**
2160
+ - LLMs are not designed for precise motor control
2161
+ - SmolVLA trained specifically on robot demonstrations
2162
+ - Gemini would require extensive prompting for each action
2163
+ - SmolVLA provides closed-loop visual feedback
2164
+ - Separation of concerns (language understanding vs. motor control)
2165
+
2166
+ **Gemini's Role:**
2167
+ - Understand user intent
2168
+ - Detect manipulation commands
2169
+ - Generate conversational responses
2170
+ - Maintain character personality
2171
+
2172
+ **SmolVLA's Role:**
2173
+ - Generate precise robot actions
2174
+ - Process visual observations
2175
+ - Execute manipulation tasks
2176
+ - Handle low-level control
2177
+
2178
+ ### 9. Why Store Checkpoints Locally vs. Cloud?
2179
+
2180
+ **Decision:** Store model checkpoints locally with optional cloud backup.
2181
+
2182
+ **Rationale:**
2183
+ - Faster loading (no network latency)
2184
+ - No cloud storage costs during development
2185
+ - Privacy (model stays on local machine)
2186
+ - Simpler deployment
2187
+
2188
+ **Cloud Backup Strategy:**
2189
+ - Push final models to Hugging Face Hub
2190
+ - Version control with git-lfs
2191
+ - Disaster recovery
2192
+
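+ A sketch of the backup step using `huggingface_hub` (the repo id is illustrative):
+
+ ```python
+ from huggingface_hub import HfApi
+
+ api = HfApi()
+ api.upload_folder(
+     folder_path="checkpoints/smolvla_best",     # local checkpoint directory
+     repo_id="your-username/smolvla_mortis_v1",  # illustrative repo id
+     repo_type="model",
+ )
+ ```
+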
2193
+ ### 10. Why 6 Specific Manipulation Tasks?
2194
+
2195
+ **Decision:** Start with 6 predefined manipulation tasks (skull/eyeball × 3 cups).
2196
+
2197
+ **Rationale:**
2198
+ - Manageable scope for initial implementation
2199
+ - Sufficient variety to demonstrate capability
2200
+ - Fits Halloween theme
2201
+ - Realistic data collection effort (30-60 demonstrations)
2202
+ - Can expand later with more tasks
2203
+
2204
+ **Expansion Path:**
2205
+ - Add more objects (pumpkin, spider, etc.)
2206
+ - Add more target locations
2207
+ - Add multi-step tasks
2208
+ - Add task composition
2209
+
2210
+
2211
+ ## Future Enhancements
2212
+
2213
+ ### Short-term (3-6 months)
2214
+
2215
+ 1. **Expanded Task Set**
2216
+ - Add 10-20 more manipulation tasks
2217
+ - Support task composition ("pick up skull, then eyeball")
2218
+ - Add multi-object interactions
2219
+
2220
+ 2. **Improved Voice Interaction**
2221
+ - Wake word detection ("Hey Mortis")
2222
+ - Continuous conversation mode
2223
+ - Voice activity detection
2224
+ - Speaker identification
2225
+
2226
+ 3. **Enhanced Safety**
2227
+ - Computer vision-based collision detection
2228
+ - Force/torque sensing
2229
+ - Workspace boundary enforcement
2230
+ - Automatic emergency stop
2231
+
2232
+ 4. **Performance Optimization**
2233
+ - Model quantization for faster inference
2234
+ - Action caching for repeated tasks
2235
+ - Parallel processing for multiple requests
2236
+ - GPU memory optimization
2237
+
2238
+ ### Medium-term (6-12 months)
2239
+
2240
+ 1. **Advanced Learning**
2241
+ - Online learning from corrections
2242
+ - Few-shot task learning
2243
+ - Transfer learning to new objects
2244
+ - Self-supervised improvement
2245
+
2246
+ 2. **Multi-Robot Support**
2247
+ - Control multiple SO101 arms
2248
+ - Coordinated multi-arm tasks
2249
+ - Load balancing across robots
2250
+ - Distributed task execution
2251
+
2252
+ 3. **Enhanced Perception**
2253
+ - 3D object detection
2254
+ - Depth estimation
2255
+ - Object tracking
2256
+ - Scene understanding
2257
+
2258
+ 4. **User Personalization**
2259
+ - User profiles and preferences
2260
+ - Adaptive difficulty
2261
+ - Custom task definitions
2262
+ - Voice profile learning
2263
+
2264
+ ### Long-term (12+ months)
2265
+
2266
+ 1. **Autonomous Task Planning**
2267
+ - High-level goal specification
2268
+ - Automatic task decomposition
2269
+ - Multi-step planning
2270
+ - Failure recovery strategies
2271
+
2272
+ 2. **Natural Language Programming**
2273
+ - Teach new tasks through conversation
2274
+ - Automatic demonstration collection
2275
+ - Interactive refinement
2276
+ - Task library management
2277
+
2278
+ 3. **Advanced Interaction**
2279
+ - Gesture recognition (human gestures)
2280
+ - Facial expression detection
2281
+ - Emotion-aware responses
2282
+ - Proactive assistance
2283
+
2284
+ 4. **Production Deployment**
2285
+ - Multi-user support
2286
+ - Cloud-based inference
2287
+ - Mobile app interface
2288
+ - API for third-party integration
2289
+
2290
+ ## Conclusion
2291
+
2292
+ This design provides a comprehensive architecture for refactoring Mortis into a multi-modal, SmolVLA-powered robotic system. The design emphasizes:
2293
+
2294
+ - **Modularity:** Clear separation between components (Gemini, STT/TTS, SmolVLA, robot control)
2295
+ - **Scalability:** Asynchronous execution and queue-based architecture
2296
+ - **Reliability:** Comprehensive error handling and recovery strategies
2297
+ - **Maintainability:** Well-defined interfaces and data models
2298
+ - **Extensibility:** Clear paths for future enhancements
2299
+
2300
+ The phased migration strategy allows for incremental development and validation, reducing risk and enabling early feedback. The hybrid approach of combining predefined gestures with learned manipulation behaviors provides both reliability and flexibility.
2301
+
2302
+ Key technical decisions prioritize:
2303
+ - Google ecosystem integration (Gemini, Cloud STT/TTS)
2304
+ - Local deployment with GPU support
2305
+ - LeRobot framework for robotics
2306
+ - Gradio for rapid UI development
2307
+ - Python-native solutions (asyncio, threading)
2308
+
2309
+ The design is ready for implementation following the task list in the next phase of the spec workflow.
.kiro/specs/gemini-multimodal-refactor/requirements.md ADDED
@@ -0,0 +1,153 @@
1
+ # Requirements Document
2
+
3
+ ## Introduction
4
+
5
+ This document specifies the requirements for refactoring the Mortis interactive AI Halloween experience to use Google Gemini API with multi-modal (voice and text) interaction capabilities. The refactor replaces the existing LLM API integration and adds SmolVLA-based robotic control for specific manipulation tasks. The system must maintain the character-driven conversational experience while enabling precise robotic manipulation through voice or text commands.
6
+
7
+ ## Glossary
8
+
9
+ - **Mortis System**: The complete interactive AI Halloween experience including web UI, conversational AI, and robotic arm control
10
+ - **Gemini API**: Google's large language model API service used for conversational AI and intent detection
11
+ - **SmolVLA Model**: A vision-language-action model trained using LeRobot for specific robotic manipulation tasks
12
+ - **Gradio Interface**: The web-based user interface framework for the Mortis System
13
+ - **SO101 Arm**: The SeeedStudio SO101 robotic arm hardware controlled by the Mortis System
14
+ - **STT Service**: Speech-to-Text service that converts audio input to text
15
+ - **TTS Service**: Text-to-Speech service that converts text responses to audio output
16
+ - **Task String**: A specific command format recognized by the SmolVLA Model (e.g., "Pick up the skull and place it in the green cup")
17
+ - **LeRobot Framework**: The robotics framework used for dataset management, model training, and inference
18
+ - **Message Queue**: An asynchronous communication mechanism for decoupling robotic execution from the web interface
19
+ - **Cloud-Agnostic Architecture**: A system design that does not depend on vendor-specific cloud platform services (like AWS Lambda, Azure Functions, or GCP Cloud Run), allowing deployment on any infrastructure including local hardware
20
+
21
+ ## Requirements
22
+
23
+ ### Requirement 1: Gemini API Integration
24
+
25
+ **User Story:** As a developer, I want to replace the existing LLM API with Google Gemini API, so that the system uses Google's language model for all conversational interactions.
26
+
27
+ #### Acceptance Criteria
28
+
29
+ 1. THE Mortis System SHALL use the Google Gemini API for all language model interactions
30
+ 2. THE Mortis System SHALL support multiple Gemini model variants through configuration
31
+ 3. THE Mortis System SHALL authenticate with the Gemini API using API keys stored in environment variables
32
+ 4. THE Mortis System SHALL handle Gemini API errors gracefully and provide user feedback when API calls fail
33
+ 5. THE Mortis System SHALL maintain response times under 5 seconds for typical conversational interactions
34
+
35
+ ### Requirement 2: Multi-Modal Voice Input
36
+
37
+ **User Story:** As a user, I want to speak to Mortis through my microphone, so that I can interact naturally without typing.
38
+
39
+ #### Acceptance Criteria
40
+
41
+ 1. THE Gradio Interface SHALL provide an audio input component for capturing user voice
42
+ 2. WHEN a user provides voice input, THE Mortis System SHALL convert the audio to text using a Speech-to-Text service
43
+ 3. THE Mortis System SHALL support both cloud-based STT services and local STT models as configurable options
44
+ 4. THE Mortis System SHALL process voice input with latency under 3 seconds for utterances under 10 seconds
45
+ 5. THE Mortis System SHALL display the transcribed text to the user for confirmation
46
+
47
+ ### Requirement 3: Intent Detection and Command Routing
48
+
49
+ **User Story:** As a system, I want to detect when user input matches a specific robotic task command, so that I can route the request to the appropriate control mechanism.
50
+
51
+ #### Acceptance Criteria
52
+
53
+ 1. THE Gemini API SHALL receive a system prompt that defines all valid SmolVLA Task Strings
54
+ 2. WHEN the Gemini API processes user input, THE Mortis System SHALL determine if the input matches a valid Task String
55
+ 3. IF the user input matches a valid Task String, THEN THE Mortis System SHALL extract the exact command string for robotic execution
56
+ 4. IF the user input does not match a valid Task String, THEN THE Mortis System SHALL generate a standard conversational response with gesture control
57
+ 5. THE Mortis System SHALL return both a conversational response and a command indicator in a structured format
58
+
59
+ ### Requirement 4: Dataset Creation and Collection
60
+
61
+ **User Story:** As a developer, I want to create and collect demonstration data for robotic manipulation tasks, so that I have training data for the SmolVLA model.
62
+
63
+ #### Acceptance Criteria
64
+
65
+ 1. THE Mortis System SHALL provide a data collection script for recording SO101 Arm demonstrations
66
+ 2. THE Mortis System SHALL capture synchronized camera observations and robot actions during demonstrations
67
+ 3. THE Mortis System SHALL save collected demonstrations in LeRobot-compatible format
68
+ 4. THE Mortis System SHALL support labeling demonstrations with corresponding Task String commands
69
+ 5. THE Mortis System SHALL validate collected data for completeness before adding to the training dataset
70
+
71
+ ### Requirement 5: SmolVLA Model Training Infrastructure
72
+
73
+ **User Story:** As a developer, I want to train a SmolVLA model using LeRobot with collected demonstration data, so that the robot can perform precise manipulation tasks.
74
+
75
+ #### Acceptance Criteria
76
+
77
+ 1. THE Mortis System SHALL provide a training script that loads datasets from local LeRobot databases or Hugging Face
78
+ 2. THE Mortis System SHALL create and manage LeRobot dataset databases for training data
79
+ 3. THE Mortis System SHALL configure SmolVLA training using lerobot-train with appropriate hyperparameters
80
+ 4. THE Mortis System SHALL save trained model checkpoints to a configurable directory
81
+ 5. THE Mortis System SHALL log training metrics including loss, accuracy, and validation performance
82
+
83
+ ### Requirement 6: SmolVLA Inference Execution
84
+
85
+ **User Story:** As a system, I want to execute SmolVLA model inference when a valid task command is detected, so that the robot performs the requested manipulation.
86
+
87
+ #### Acceptance Criteria
88
+
89
+ 1. THE Mortis System SHALL load the trained SmolVLA Model from saved checkpoints
90
+ 2. WHEN a valid Task String is received, THE Mortis System SHALL execute SmolVLA inference with the command as input
91
+ 3. THE Mortis System SHALL control the SO101 Arm through the SmolVLA Model output actions
92
+ 4. THE Mortis System SHALL provide visual feedback during robotic execution through the webcam view
93
+ 5. THE Mortis System SHALL handle inference errors and return the robot to a safe idle state
94
+
95
+ ### Requirement 7: Asynchronous Robotic Execution
96
+
97
+ **User Story:** As a user, I want the web interface to remain responsive while the robot executes tasks, so that I can monitor progress without the UI freezing.
98
+
99
+ #### Acceptance Criteria
100
+
101
+ 1. THE Mortis System SHALL execute SmolVLA inference asynchronously without blocking the Gradio Interface
102
+ 2. THE Mortis System SHALL use a message queue or background processing mechanism to decouple inference from the web interface
103
+ 3. WHILE SmolVLA inference is executing, THE Gradio Interface SHALL display a status indicator showing task progress
104
+ 4. THE Mortis System SHALL allow users to view the robot's actions through the webcam during execution
105
+ 5. WHEN robotic execution completes, THE Mortis System SHALL update the interface with completion status
106
+
107
+ ### Requirement 8: Voice Output Integration
108
+
109
+ **User Story:** As a user, I want to hear Mortis speak responses aloud, so that I can experience a fully voice-based interaction.
110
+
111
+ #### Acceptance Criteria
112
+
113
+ 1. THE Mortis System SHALL convert Gemini API text responses to audio using a Text-to-Speech service
114
+ 2. THE Mortis System SHALL support Google TTS or equivalent widely-available TTS services
115
+ 3. THE Gradio Interface SHALL play generated audio responses automatically after receiving them
116
+ 4. THE Mortis System SHALL generate audio in a format compatible with web browsers (MP3 or WAV)
117
+ 5. THE Mortis System SHALL maintain character voice consistency across all audio responses
118
+
119
+ ### Requirement 9: Architecture and Deployment
120
+
121
+ **User Story:** As a developer, I want a system that can run on local hardware without vendor-specific cloud dependencies, so that I can deploy it flexibly while using Google APIs for LLM services.
122
+
123
+ #### Acceptance Criteria
124
+
125
+ 1. THE Mortis System SHALL NOT depend on vendor-specific cloud platform services such as AWS Lambda, Azure Functions, or GCP Cloud Run
126
+ 2. THE Mortis System SHALL support deployment on local hardware with GPU access for SmolVLA inference
127
+ 3. THE Mortis System SHALL use standard Python libraries and open-source frameworks for all non-Google API components
128
+ 4. THE Mortis System SHALL document all external service dependencies in the environment configuration
129
+ 5. THE Mortis System SHALL provide configuration options for switching between cloud-based and local STT and TTS processing
130
+
131
+ ### Requirement 10: Backward Compatibility and Migration
132
+
133
+ **User Story:** As a developer, I want to migrate from the existing LLM API to Gemini without losing existing functionality, so that users experience a seamless transition.
134
+
135
+ #### Acceptance Criteria
136
+
137
+ 1. THE Mortis System SHALL maintain all existing gesture capabilities during the refactor
138
+ 2. THE Mortis System SHALL preserve the Halloween character theme and response style
139
+ 3. THE Mortis System SHALL continue to support text-only interaction for users without microphones
140
+ 4. THE Mortis System SHALL maintain the existing Gradio Interface layout and visual design
141
+ 5. THE Mortis System SHALL provide a migration guide documenting configuration changes
142
+
143
+ ### Requirement 11: Error Handling and Robustness
144
+
145
+ **User Story:** As a user, I want the system to handle errors gracefully, so that temporary failures don't break my interaction experience.
146
+
147
+ #### Acceptance Criteria
148
+
149
+ 1. IF the Gemini API is unavailable, THEN THE Mortis System SHALL display an error message and allow retry
150
+ 2. IF STT conversion fails, THEN THE Mortis System SHALL prompt the user to try again or use text input
151
+ 3. IF SmolVLA inference fails, THEN THE Mortis System SHALL return the SO101 Arm to idle position safely
152
+ 4. IF TTS generation fails, THEN THE Mortis System SHALL display the text response without audio
153
+ 5. THE Mortis System SHALL log all errors with sufficient detail for debugging
.kiro/specs/gemini-multimodal-refactor/tasks.md ADDED
@@ -0,0 +1,403 @@
1
+ # Implementation Plan
2
+
3
+ This implementation plan breaks down the Gemini multi-modal refactor into discrete, actionable coding tasks. Each task builds incrementally on previous work, following the 8-phase migration strategy outlined in the design document.
4
+
5
+ ## Important Note: Hybrid Async Execution System
6
+
7
+ **Phase 7** uses a **hybrid approach** for asynchronous execution:
8
+
9
+ 1. **AsyncExecutor** (`src/mortis/async_executor.py`): Simple Python threading system for quick gesture tasks
10
+ - Use for: wave, point, idle, grab, drop gestures
11
+ - Advantages: Simple, fast (1-2s), low overhead
12
+ - Implementation: Task queue + worker thread + status queue
13
+
14
+ 2. **LeRobotAsyncClient** (`src/mortis/lerobot_async_client.py`): Wrapper over LeRobot's async inference system
15
+ - Use for: Complex manipulation tasks with SmolVLA
16
+ - Advantages: Optimized for continuous inference, handles action chunks, real-time control
17
+ - Implementation: PolicyServer + RobotClient + gRPC communication
18
+
19
+ This hybrid approach provides the best of both worlds: simplicity for gestures and power for manipulation.
20
+
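+ As a reduced illustration of the routing this implies (method and attribute names besides `ManipulationTask` are hypothetical glue code):
+
+ ```python
+ from mortis.lerobot_async_client import ManipulationTask  # defined in task 28
+
+ def dispatch(intent, async_executor, lerobot_client, arm) -> None:
+     """Hypothetical glue: send each intent to the matching executor."""
+     if intent.is_manipulation:
+         # Heavy SmolVLA work runs through LeRobot's async inference stack
+         lerobot_client.execute(ManipulationTask(command=intent.command))
+     else:
+         # Quick predefined gestures use the lightweight worker thread
+         async_executor.submit(arm.play_gesture, intent.gesture)
+ ```
+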
21
+ ## Phase 1: Gemini API Integration
22
+
23
+ - [x] 1. Set up Gemini API client infrastructure
24
+ - Create `src/mortis/gemini_client.py` module
25
+ - Implement `GeminiClient` class with configuration management
26
+ - Add environment variable handling for `GEMINI_API_KEY`, `GEMINI_MODEL`, `GEMINI_TEMPERATURE`
27
+ - Implement basic `send_message()` method using `google.generativeai` SDK
28
+ - _Requirements: 1.1, 1.2, 1.3_
29
+
30
+ - [x] 2. Implement structured response parsing
31
+ - Create `src/mortis/models.py` for data models
32
+ - Implement `GeminiResponse`, `ResponseType`, `Mood`, `Gesture` enums and dataclasses
33
+ - Add `from_json()` method for parsing Gemini JSON responses
34
+ - Implement response validation logic
35
+ - _Requirements: 1.1, 3.5_
36
+
37
+ - [x] 3. Design and implement Gemini system prompt
38
+ - Create system prompt with Mortis character definition
39
+ - Add manipulation task definitions (6 tasks) to prompt
40
+ - Implement JSON response format specification in prompt
41
+ - Configure Gemini to use JSON mode (`response_mime_type: application/json`)
42
+ - _Requirements: 3.1, 3.2, 9.2_
43
+
44
+ - [x] 4. Implement error handling and retry logic
45
+ - Add exponential backoff retry for rate limiting
46
+ - Handle `BlockedPromptException` with fallback responses
47
+ - Implement timeout handling for API calls
48
+ - Add error logging and user-friendly error messages
49
+ - _Requirements: 1.4, 11.1_
50
+
51
+ - [x] 5. Replace existing LLM API in tools.py
52
+ - Refactor `ask_mortis()` function to use `GeminiClient`
53
+ - Update response parsing to use new data models
54
+ - Maintain backward compatibility with gesture execution
55
+ - Update environment configuration documentation
56
+ - _Requirements: 1.1, 9.1, 9.4_
57
+
58
+ - [ ]* 5.1 Write integration tests for Gemini client
59
+ - Test successful API calls and response parsing
60
+ - Test retry logic with mocked rate limit errors
61
+ - Test fallback responses on API failures
62
+ - Verify character personality maintained in responses
63
+ - _Requirements: 1.1, 1.4_
64
+
65
+
66
+ ## Phase 2: Voice Input and Output
67
+
68
+ - [x] 6. Implement Speech-to-Text service
69
+ - Create `src/mortis/stt_service.py` module
70
+ - Implement `STTService` class with Gemini native audio support
71
+ - Add fallback to Google Cloud Speech-to-Text API
72
+ - Implement audio file format validation and conversion
73
+ - Add configuration for STT service selection (Gemini vs Google STT)
74
+ - _Requirements: 2.1, 2.2, 2.3_
75
+
76
+ - [x] 7. Implement Text-to-Speech service
77
+ - Create `src/mortis/tts_service.py` module
78
+ - Implement `TTSService` class using Google Cloud TTS
79
+ - Configure voice parameters (pitch, speed) for Mortis character
80
+ - Implement audio file generation (MP3 format)
81
+ - Add local TTS fallback (gTTS) for offline scenarios
82
+ - _Requirements: 8.1, 8.2, 8.4, 8.5_
83
+
84
+ - [x] 8. Update Gradio UI for audio input
85
+ - Add `gr.Audio` component for microphone input to `app.py`
86
+ - Implement audio input handler function
87
+ - Connect audio input to STT service
88
+ - Display transcribed text to user for confirmation
89
+ - Handle audio processing errors gracefully
90
+ - _Requirements: 2.1, 2.5, 11.2_
91
+
92
+ - [x] 9. Update Gradio UI for audio output
93
+ - Add `gr.Audio` component for audio playback
94
+ - Implement audio response generation in `mortis_reply()`
95
+ - Configure autoplay for audio responses
96
+ - Create `outputs/` directory for temporary audio files
97
+ - Implement audio file cleanup mechanism
98
+ - _Requirements: 8.3, 8.4_
99
+
100
+ - [x] 10. Integrate voice flow with Gemini
101
+ - Update `ask_mortis()` to accept audio input
102
+ - Implement voice-to-text-to-Gemini-to-TTS pipeline
103
+ - Maintain text input compatibility
104
+ - Add latency monitoring for voice processing
105
+ - _Requirements: 2.4, 9.3_
106
+
107
+ - [ ]* 10.1 Write tests for audio processing
108
+ - Test STT with sample audio files
109
+ - Test TTS output quality and format
110
+ - Test audio input/output in Gradio UI
111
+ - Verify fallback mechanisms work correctly
112
+ - _Requirements: 2.2, 8.2, 11.2, 11.4_
113
+
114
+
115
+ ## Phase 3: Dataset Collection Infrastructure
116
+
117
+ - [x] 11. Set up LeRobot dataset infrastructure
118
+ - Create `src/mortis/data_collector.py` module
119
+ - Implement `DataCollector` class with LeRobot dataset integration
120
+ - Configure dataset directory structure (`data/mortis_manipulation/`)
121
+ - Implement dataset metadata management (task descriptions, episode counts)
122
+ - _Requirements: 4.3, 5.2_
123
+
124
+ - [ ]* 12. Implement camera integration for data collection
125
+ - Add camera initialization in `DataCollector`
126
+ - Implement synchronized image capture with robot state
127
+ - Configure camera parameters (resolution, FPS)
128
+ - Add camera calibration utilities
129
+ - _Requirements: 4.2_
130
+
131
+ - [ ]* 13. Implement episode recording functionality
132
+ - Create `record_episode()` method for capturing demonstrations
133
+ - Implement real-time data capture loop (30 FPS)
134
+ - Add keyboard controls for start/stop recording
135
+ - Implement episode data validation
136
+ - Save episodes in LeRobot-compatible format
137
+ - _Requirements: 4.1, 4.2, 4.5_
138
+
139
+ - [ ]* 14. Implement task labeling system
140
+ - Add task description input for each episode
141
+ - Create task label validation against predefined task set
142
+ - Implement episode metadata storage
143
+ - Add episode review and re-recording capability
144
+ - _Requirements: 4.4_
145
+
146
+ - [ ]* 15. Create data collection CLI script
147
+ - Create `src/mortis/collect_data.py` entry point
148
+ - Implement interactive data collection workflow
149
+ - Add progress tracking (episodes per task)
150
+ - Implement dataset statistics display
151
+ - Add Hugging Face Hub upload functionality
152
+ - _Requirements: 4.1, 4.3, 5.1_
153
+
154
+ - [ ]* 15.1 Write data validation tests
155
+ - Test episode data format compliance
156
+ - Verify synchronized timestamps
157
+ - Check image quality and dimensions
158
+ - Validate action sequences
159
+ - _Requirements: 4.5_
160
+
161
+
162
+ ## Phase 4: SmolVLA Training Pipeline
163
+
164
+ - [ ]* 16. Create training configuration
165
+ - Create `config/train_smolvla.yaml` with Hydra configuration
166
+ - Configure SmolVLA policy parameters (vision backbone, chunk size)
167
+ - Set training hyperparameters (batch size, learning rate, steps)
168
+ - Configure evaluation settings
169
+ - Add Weights & Biases integration configuration
170
+ - _Requirements: 5.3, 5.5_
171
+
172
+ - [x] 17. Implement training script
173
+ - Create `src/mortis/train.py` module
174
+ - Implement dataset loading from local or Hugging Face
175
+ - Configure LeRobot training pipeline
176
+ - Add checkpoint saving logic
177
+ - Implement training progress logging
178
+ - _Requirements: 5.1, 5.2, 5.4, 5.5_
179
+
180
+ - [ ]* 18. Set up training monitoring
181
+ - Integrate Weights & Biases for metric tracking
182
+ - Log training loss, validation loss, learning rate
183
+ - Add sample prediction visualization
184
+ - Implement early stopping based on validation performance
185
+ - _Requirements: 5.5_
186
+
187
+ - [ ]* 19. Create training execution commands
188
+ - Add `train-smolvla` target to Makefile
189
+ - Document training command with all parameters
190
+ - Add GPU memory optimization flags
191
+ - Create training resume functionality for interrupted runs
192
+ - _Requirements: 5.3, 5.4_
193
+
194
+ - [ ]* 19.1 Write training validation tests
195
+ - Test dataset loading and batching
196
+ - Verify model architecture initialization
197
+ - Test checkpoint saving and loading
198
+ - Validate training loop executes without errors
199
+ - _Requirements: 5.2, 5.4_
200
+
201
+
202
+ ## Phase 5: SmolVLA Inference Integration
203
+
204
+ - [x] 20. Implement SmolVLA executor
205
+ - Create `src/mortis/smolvla_executor.py` module
206
+ - Implement `SmolVLAExecutor` class with model loading
207
+ - Add checkpoint loading from configurable path
208
+ - Implement GPU device management
209
+ - Add model initialization and warmup
210
+ - _Requirements: 6.1, 8.2_
211
+
212
+ - [x] 21. Implement observation capture
213
+ - Add camera integration for visual observations
214
+ - Implement robot state capture from SO101
215
+ - Create observation dictionary formatting for SmolVLA
216
+ - Add tensor conversion and device placement
217
+ - _Requirements: 6.2, 6.4_
218
+
219
+ - [x] 22. Implement action execution loop
220
+ - Create `execute()` method for task execution
221
+ - Implement inference loop with visual feedback
222
+ - Add action tensor to SO101 command conversion
223
+ - Implement step-by-step action execution
224
+ - Add task completion detection logic
225
+ - _Requirements: 6.2, 6.3_
226
+
227
+ - [x] 23. Implement safety and error handling
228
+ - Add command validation against trained task set
229
+ - Implement workspace safety checks
230
+ - Add emergency stop functionality
231
+ - Implement timeout handling for long-running tasks
232
+ - Add GPU out-of-memory recovery
233
+ - _Requirements: 6.5, 11.3_
234
+
235
+ - [ ]* 23.1 Write SmolVLA inference tests
236
+ - Test model loading from checkpoint
237
+ - Test observation capture and formatting
238
+ - Test action prediction and execution
239
+ - Verify emergency stop functionality
240
+ - _Requirements: 6.1, 6.3, 6.5_
241
+
242
+
243
+ ## Phase 6: Intent Detection and Routing
244
+
245
+ - [x] 24. Implement intent router
246
+ - Create `src/mortis/intent_router.py` module
247
+ - Implement `IntentRouter` class with task definitions
248
+ - Add `parse_gemini_response()` method for JSON parsing
249
+ - Implement command validation logic
250
+ - Create `Intent` dataclass for structured intent representation
251
+ - _Requirements: 3.2, 3.3, 3.4, 3.5_
252
+
253
+ - [x] 25. Update Gemini prompt for intent detection
254
+ - Enhance system prompt with all 6 manipulation task definitions
255
+ - Add clear response format specification for manipulation vs conversation
256
+ - Implement intent type detection in prompt
257
+ - Add examples of manipulation and conversational inputs
258
+ - _Requirements: 3.1, 3.2_
259
+
260
+ - [x] 26. Integrate intent routing in main flow
261
+ - Update `ask_mortis()` to use `IntentRouter`
262
+ - Implement routing logic for manipulation vs gesture execution
263
+ - Add command validation before SmolVLA execution
264
+ - Implement fallback to gestures for invalid commands
265
+ - _Requirements: 3.3, 3.4, 3.5_
266
+
267
+ - [ ]* 26.1 Write intent detection tests
268
+ - Test parsing of manipulation responses
269
+ - Test parsing of conversational responses
270
+ - Test command validation logic
271
+ - Verify fallback behavior for invalid commands
272
+ - Test edge cases and malformed responses
273
+ - _Requirements: 3.2, 3.3, 3.4_
274
+
275
+
276
+ ## Phase 7: Asynchronous Execution System (Hybrid Approach)
277
+
278
+ **Note**: This phase uses a hybrid execution system:
279
+ - **AsyncExecutor**: Simple threading for quick gestures (wave, point, idle)
280
+ - **LeRobotAsyncClient**: LeRobot async inference (PolicyServer + RobotClient) for complex manipulation tasks with SmolVLA
281
+
282
+ - [x] 27. Implement async executor infrastructure for gestures
283
+ - Create `src/mortis/async_executor.py` module
284
+ - Implement `AsyncExecutor` class with task queue
285
+ - Add background worker thread for task processing
286
+ - Implement status queue for progress updates
287
+ - Add start/stop methods for executor lifecycle
288
+ - Create `Task` and `StatusUpdate` dataclasses
289
+ - Add comprehensive tests (15 tests, all passing)
290
+ - _Requirements: 7.1, 7.2_
291
+
292
+ - [x] 28. Implement LeRobot async client for manipulation
293
+ - Create `src/mortis/lerobot_async_client.py` module
294
+ - Implement `LeRobotAsyncClient` wrapper class
295
+ - Integrate PolicyServer and RobotClient from LeRobot
296
+ - Add `ManipulationTask` and `ManipulationStatus` models
297
+ - Implement lifecycle management (start/stop)
298
+ - Add task execution with status tracking
299
+ - Create demo scripts and documentation
300
+ - _Requirements: 7.1, 7.2, 7.5_
301
+
302
+ - [x] 29. Integrate hybrid execution in main application
303
+ - Initialize both AsyncExecutor and LeRobotAsyncClient in app.py
304
+ - Update `mortis_reply()` to route gestures to AsyncExecutor
305
+ - Update `mortis_reply()` to route manipulation to LeRobotAsyncClient
306
+ - Implement proper lifecycle management (start on app load, stop on unload)
307
+ - Handle errors from both systems
308
+ - _Requirements: 7.1, 7.2, 7.5_
309
+
310
+ - [x] 30. Add hybrid status display to Gradio UI
311
+ - Add status textbox component to UI for robot status
312
+ - Implement `check_status()` function that monitors both systems
313
+ - Check AsyncExecutor for gesture status updates
314
+ - Check LeRobotAsyncClient for manipulation status
315
+ - Configure Gradio to poll status every 500ms
316
+ - Display appropriate icons and messages for each system
317
+ - Add visual indicators for different task states (idle, running, complete, failed)
318
+ - _Requirements: 7.3, 7.4, 7.5_
319
+
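+ A sketch of the 500 ms polling wiring in Gradio; the `check_status` body is schematic:
+
+ ```python
+ import gradio as gr
+
+ def check_status() -> str:
+     """Schematic: merge the latest updates from both executors into one line."""
+     return "idle"  # replace with real AsyncExecutor / LeRobotAsyncClient checks
+
+ with gr.Blocks() as demo:
+     status_box = gr.Textbox(label="Robot status", interactive=False)
+     timer = gr.Timer(0.5)  # fire every 500 ms
+     timer.tick(check_status, outputs=status_box)
+ ```
+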
320
+ - [x] 31. Test and validate hybrid execution system
321
+ - Test gesture execution via AsyncExecutor
322
+ - Test manipulation execution via LeRobotAsyncClient
323
+ - Verify both systems can run concurrently
324
+ - Test status updates from both systems
325
+ - Verify UI remains responsive during long manipulation tasks
326
+ - Test error handling in both systems
327
+ - Validate proper cleanup on app shutdown
328
+ - _Requirements: 7.1, 7.2, 7.3, 7.4, 7.5_
329
+
330
+ - [ ]* 31.1 Write integration tests for hybrid system
331
+ - Test AsyncExecutor with mock gesture executor
332
+ - Test LeRobotAsyncClient with mock PolicyServer/RobotClient
333
+ - Test routing logic (gesture vs manipulation)
334
+ - Test concurrent execution of gestures and manipulation
335
+ - Verify status updates from both systems
336
+ - Test error recovery and fallback behavior
337
+ - _Requirements: 7.1, 7.2, 7.3, 7.5_
338
+
339
+
340
+ ## Phase 8: Integration, Testing, and Deployment
341
+
342
+ - [ ]* 32. Update project dependencies
343
+ - Update `pyproject.toml` with new dependencies (google-generativeai, google-cloud-speech, google-cloud-texttospeech)
344
+ - Add optional dependencies for training (wandb, hydra-core)
345
+ - Update Makefile with new commands (collect-data, train-smolvla)
346
+ - Run `make install` to sync dependencies
347
+ - _Requirements: 8.4, 9.5_
348
+
349
+ - [ ]* 33. Update environment configuration
350
+ - Create `.env.example` with all required variables
351
+ - Document Gemini API key setup
352
+ - Document Google Cloud credentials setup
353
+ - Add SmolVLA checkpoint path configuration
354
+ - Update README with new environment variables
355
+ - _Requirements: 1.3, 8.4_
356
+
357
+ - [ ]* 34. Implement logging and monitoring
358
+ - Add structured logging throughout application
359
+ - Log Gemini API calls and response times
360
+ - Log SmolVLA inference times and success rates
361
+ - Add error logging with stack traces
362
+ - Create log rotation and cleanup
363
+ - _Requirements: 11.5_
364
+
365
+ - [x] 35. Create comprehensive documentation
366
+ - Update README with new features and setup instructions
367
+ - Document data collection workflow
368
+ - Document training process
369
+ - Create user guide for voice interaction
370
+ - Add troubleshooting section
371
+ - _Requirements: 8.4, 9.5_
372
+
373
+ - [ ]* 36. Perform end-to-end integration testing
374
+ - Test complete voice input → Gemini → SmolVLA → audio output flow
375
+ - Test text input → intent detection → gesture execution flow
376
+ - Test error handling and recovery across all components
377
+ - Verify UI responsiveness during long operations
378
+ - Test with all 6 manipulation tasks
379
+ - _Requirements: 9.1, 9.2, 9.3, 9.4_
380
+
381
+ - [ ]* 36.1 Write system-level tests
382
+ - Test multi-modal interaction flows
383
+ - Test concurrent user requests
384
+ - Test resource usage (GPU memory, CPU)
385
+ - Benchmark performance metrics
386
+ - _Requirements: 1.5, 2.4, 7.3_
387
+
388
+ - [ ]* 37. Optimize performance
389
+ - Profile Gemini API response times
390
+ - Optimize SmolVLA inference speed
391
+ - Reduce audio processing latency
392
+ - Implement caching where appropriate
393
+ - Optimize GPU memory usage
394
+ - _Requirements: 1.5, 2.4_
395
+
396
+ - [ ]* 38. Final deployment preparation
397
+ - Create deployment checklist
398
+ - Set up monitoring and alerting
399
+ - Prepare rollback procedures
400
+ - Create backup of current system
401
+ - Document deployment process
402
+ - _Requirements: 8.1, 8.2, 8.3_
403
+
.kiro/steering/product.md ADDED
@@ -0,0 +1,30 @@
1
+ ---
2
+ inclusion: always
3
+ ---
4
+
5
+ # Product Overview
6
+
7
+ **Mortis** is an interactive AI Halloween experience that combines conversational AI with physical robotics. It's a Gradio web application where users chat with "Mortis," a mischievous Halloween spirit powered by LLMs.
8
+
9
+ ## Core Concept
10
+
11
+ Mortis responds to user messages with:
12
+ - Text responses (character-driven, in-character dialogue)
13
+ - Emotional moods (ominous, playful, angry, etc.)
14
+ - Physical gestures via a SeeedStudio SO101 robotic arm controlled through LeRobot
15
+
16
+ ## Key Features
17
+
18
+ - Web UI with Halloween-themed background
19
+ - Multi-model LLM support via API
20
+ - Structured tool calling for coordinated text + gesture responses
21
+ - Real-time robotic arm control synchronized with AI responses
22
+ - Local webcam view (browser-only, no upload)
23
+
24
+ ## Character Guidelines
25
+
26
+ When working with Mortis dialogue:
27
+ - Keep responses ≤30 words, ≤120 characters
28
+ - No emojis or markdown in character responses
29
+ - Maintain Halloween/haunted theme
30
+ - Responses should feel mischievous, spectral, or ominous
.kiro/steering/structure.md ADDED
@@ -0,0 +1,85 @@
1
+ ---
2
+ inclusion: always
3
+ ---
4
+
5
+ # Project Structure
6
+
7
+ ## Directory Layout
8
+
9
+ ```
+ mortis/
+ ├── src/mortis/        # Main application package
+ │   ├── app.py         # Gradio UI and main entry point
+ │   ├── tools.py       # LLM API integration and tool calling
+ │   ├── robot.py       # Robot arm control and gesture definitions
+ │   └── calibrate.py   # Robot calibration script
+ ├── examples/          # Example/demo scripts
+ │   └── demo.py        # Simple demo runner
+ ├── assets/            # Static assets (images, backgrounds)
+ │   └── image.png      # Halloween background image
+ ├── .cache/            # Runtime cache (calibration data)
+ ├── .env               # Environment variables (not committed)
+ ├── pyproject.toml     # Project metadata and dependencies
+ ├── uv.lock            # Locked dependency versions
+ ├── Makefile           # Build and run commands
+ └── README.md          # User documentation
+ ```
27
+
28
+ ## Module Organization
29
+
30
+ ### `src/mortis/app.py`
31
+ - Gradio UI construction
32
+ - Chat interface setup
33
+ - Model selection dropdown
34
+ - CSS styling with base64-encoded background
35
+ - Main entry point (`main()` function)
36
+
37
+ ### `src/mortis/tools.py`
38
+ - LLM API client
39
+ - Tool definition for structured outputs
40
+ - `ask_mortis()` function: sends user message, receives structured response
41
+ - Coordinates LLM response with robot gesture execution
42
+ - Manages global `mortis_arm` instance
43
+
44
+ ### `src/mortis/robot.py`
45
+ - `MortisArm` class: robot connection and control
46
+ - `GESTURES` dictionary: predefined gesture sequences
47
+ - Each gesture is a list of (pose_dict, delay) tuples
48
+ - Available gestures: idle, wave, point_left, point_right, grab, drop
49
+ - Pose dictionaries specify joint positions in degrees
50
+
51
+ ### `src/mortis/calibrate.py`
52
+ - Standalone calibration script
53
+ - Configures SO101Follower with calibration directory
54
+ - Interactive calibration process
55
+
56
+ ## Code Conventions
57
+
58
+ ### Import Style
59
+ - Standard library imports first
60
+ - Third-party imports second
61
+ - Local imports last
62
+ - Use `from .module import` for intra-package imports
63
+
64
+ ### Path Handling
65
+ - Use `pathlib.Path` for all file paths
66
+ - `REPO_ROOT` defined as `Path(__file__).resolve().parents[2]`
67
+ - Relative paths from repo root for assets and config
68
+
69
+ ### Robot Control Pattern
70
+ - Always check `mortis_arm.connected` before operations
71
+ - Connect once, reuse connection
72
+ - Disconnect on app unload (Gradio `demo.unload()`)
73
+ - Gestures execute synchronously with blocking delays
74
+
75
+ ### API Response Handling
76
+ - Structured tool calling enforced via `tool_choice`
77
+ - Parse `tool_calls[0].function.arguments` as JSON
78
+ - Extract: message (str), mood (enum), gesture (enum)
79
+ - Execute gesture immediately after parsing response
80
+
81
+ ## Entry Points
82
+
83
+ Defined in `pyproject.toml`:
84
+ - `mortis` → `mortis.app:main` (run the Gradio app)
85
+ - `calibrate` → `mortis.calibrate:main` (calibrate robot)
.kiro/steering/tech.md ADDED
@@ -0,0 +1,69 @@
1
+ ---
2
+ inclusion: always
3
+ ---
4
+
5
+ # Tech Stack
6
+
7
+ ## Core Technologies
8
+
9
+ - **Python**: 3.12+ (required)
10
+ - **Package Manager**: `uv` (modern Python dependency manager)
11
+ - **Web Framework**: Gradio 5.49.1+
12
+ - **Robotics**: LeRobot 0.4.0+ with Feetech servo support
13
+ - **API Client**: requests library for LLM API
14
+ - **Environment**: python-dotenv for configuration
15
+
16
+ ## Build System
17
+
18
+ The project uses a **Makefile** for all common operations. Always prefer `make` commands over direct CLI invocations.
19
+
20
+ ### Common Commands
21
+
22
+ ```bash
23
+ # Setup and dependencies
24
+ make install # Install/sync dependencies
25
+ make sync # Alias for install
26
+ make upgrade # Upgrade all dependencies
27
+
28
+ # Running the application
29
+ make run # Run via CLI entrypoint (mortis)
30
+ make run-m # Run as Python module
31
+ make demo # Run example script
32
+
33
+ # Robot operations
34
+ make calibrate # Calibrate the SO101 arm (required first-time setup)
35
+ make test-gesture # Test individual gestures
36
+
37
+ # Development
38
+ make check-env # Verify .env configuration
39
+ make add-<package> # Add new dependency (e.g., make add-numpy)
40
+ make export # Export requirements.txt from uv.lock
41
+ make clean # Remove build artifacts
42
+ ```
43
+
44
+ ## Environment Configuration
45
+
46
+ Required `.env` file in project root:
47
+ ```
48
+ API_KEY=your_api_key
49
+ API_BASE_URL=https://api.example.com/v1/chat/completions
50
+ ROBOT_PORT=/dev/ttyACM1 # Optional, defaults to /dev/ttyACM1
51
+ PORT=7860 # Optional, defaults to 7860
52
+ ```
53
+
54
+ ## API Integration
55
+
56
+ - Uses LLM chat completions API
57
+ - Supports multiple models
58
+ - Implements structured tool calling for coordinated responses
59
+ - Tool: `perform_mortis_act` returns {message, mood, gesture} (declaration sketched below)
60
+
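+ A sketch of the tool declaration (the gesture values follow the steering notes; the mood values are assumptions):
+ 
+ ```python
+ PERFORM_MORTIS_ACT = {
+     "type": "function",
+     "function": {
+         "name": "perform_mortis_act",
+         "parameters": {
+             "type": "object",
+             "properties": {
+                 "message": {"type": "string"},
+                 "mood": {"type": "string", "enum": ["spooky", "happy", "angry"]},  # assumed values
+                 "gesture": {"type": "string", "enum": [
+                     "idle", "wave", "point_left", "point_right", "grab", "drop"]},
+             },
+             "required": ["message", "mood", "gesture"],
+         },
+     },
+ }
+ ```
+ 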
61
+ ## Robot Hardware
62
+
63
+ - **Device**: SeeedStudio SO101 robotic arm
64
+ - **Connection**: USB serial (typically /dev/ttyACM1)
65
+ - **Calibration**: Stored in `.cache/calibration/so101/`
66
+ - **Control**: LeRobot framework with SO101Follower driver
67
+ - **Modes** (see the snippet below):
68
+ - `physical` - Connects to real robot hardware (default)
69
+ - `simulation` - Simulates robot without hardware (for development/testing)
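+ 
+ The mode is selected via the `ROBOT_MODE` environment variable, as read in `src/mortis/app.py`:
+ 
+ ```python
+ import os
+ 
+ robot_mode = os.getenv("ROBOT_MODE", "physical").lower()
+ if robot_mode == "simulation":
+     ...  # hardware-only features (e.g. manipulation) are disabled
+ ```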
app.py ADDED
@@ -0,0 +1,23 @@
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import sys
4
+
5
+ # Make src/mortis visible to Python
6
+ CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
7
+ SRC_DIR = os.path.join(CURRENT_DIR, "src")
8
+ if SRC_DIR not in sys.path:
9
+ sys.path.append(SRC_DIR)
10
+
11
+ from mortis.app import ui  # or whichever function builds the chatbot
12
+
13
+ # ⚙️ Hugging Face passes the port in the PORT variable
14
+ port = int(os.getenv("PORT", "7860"))
15
+
16
+ demo = ui()  # the Chatbot/ChatInterface is assembled inside ui()
17
+
18
+ if __name__ == "__main__":
19
+ demo.launch(
20
+ server_name="0.0.0.0",  # IMPORTANT when running in Docker!
21
+ server_port=port,
22
+ show_error=True,
23
+ )
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  google-genai>=1.53.0
2
- google-cloud-texttospeech>=2.16.0"
3
  gradio==5.49.1
4
  gtts>=2.5.0
5
  lerobot[async,feetech,intelrealsense,smolvla]>=0.4.0
 
1
  google-genai>=1.53.0
2
+ google-cloud-texttospeech>=2.16.0
3
  gradio==5.49.1
4
  gtts>=2.5.0
5
  lerobot[async,feetech,intelrealsense,smolvla]>=0.4.0
src/mortis/__init__.py ADDED
File without changes
src/mortis/app.py ADDED
@@ -0,0 +1,815 @@
1
+ import base64
2
+ import json
3
+ import os
4
+ import logging
5
+ import time
6
+
7
+ from pathlib import Path
8
+ import gradio as gr
9
+
10
+ from .tools import ask_mortis, mortis_arm
11
+ from .stt_service import STTService, AudioProcessingError
12
+ from .tts_service import get_tts_service
13
+ from .async_executor import AsyncExecutor, Task, TaskType, TaskStatus
14
+ from .lerobot_async_client import LeRobotAsyncClient, ManipulationStatus
15
+ from .intent_router import IntentRouter, Intent
16
+ from .models import ResponseType
17
+
18
+
19
+ REPO_ROOT = Path(__file__).resolve().parents[2]
20
+ BG_IMAGE = REPO_ROOT / "assets" / "kiroween.png"
21
+
22
+ MODEL_CHOICES = [
23
+ "gemini-2.5-flash",
24
+ "gemini-2.0-flash-exp",
25
+ "gemini-1.5-pro",
26
+ "gemini-1.5-flash",
27
+ ]
28
+
29
+ # Initialize STT service (global instance)
30
+ stt_service = None
31
+
32
+ # Initialize async execution systems (global instances)
33
+ async_executor = None
34
+ lerobot_client = None
35
+ intent_router = None
36
+
37
+ def get_stt_service():
38
+ """Lazy initialization of STT service."""
39
+ global stt_service
40
+ if stt_service is None:
41
+ try:
42
+ stt_service = STTService()
43
+ logging.getLogger(__name__).info("✅ STT service initialized")
44
+ except Exception as e:
45
+ logging.getLogger(__name__).error(f"❌ Failed to initialize STT service: {e}")
46
+ raise
47
+ return stt_service
48
+
49
+
50
+ # Initialize TTS service (global instance)
51
+ tts_service = None
52
+
53
+ def get_tts_service_instance():
54
+ """Lazy initialization of TTS service."""
55
+ global tts_service
56
+ if tts_service is None:
57
+ try:
58
+ tts_service = get_tts_service()
59
+ logging.getLogger(__name__).info("✅ TTS service initialized")
60
+ except Exception as e:
61
+ logging.getLogger(__name__).error(f"❌ Failed to initialize TTS service: {e}")
62
+ raise
63
+ return tts_service
64
+
65
+
66
+ def execute_async_task(task: Task):
67
+ """
68
+ Execute a task asynchronously (called by AsyncExecutor worker thread).
69
+
70
+ This function is called by the AsyncExecutor's worker thread to execute
71
+ tasks. It handles both gesture and manipulation tasks.
72
+
73
+ Args:
74
+ task: Task to execute
75
+ """
76
+ logger = logging.getLogger(__name__)
77
+
78
+ try:
79
+ if task.type == TaskType.GESTURE:
80
+ # Execute gesture using mortis_arm
81
+ gesture = task.gesture
82
+ logger.info(f"Executing gesture: {gesture}")
83
+
84
+ if mortis_arm.connected:
85
+ mortis_arm.move_arm(gesture)
86
+ else:
87
+ logger.warning("Robot arm not connected, skipping gesture")
88
+
89
+ elif task.type == TaskType.MANIPULATION:
90
+ # This shouldn't happen - manipulation goes through LeRobotAsyncClient
91
+ logger.warning(f"Manipulation task in AsyncExecutor: {task.command}")
92
+ logger.warning("Manipulation tasks should use LeRobotAsyncClient")
93
+
94
+ else:
95
+ logger.error(f"Unknown task type: {task.type}")
96
+
97
+ except Exception as e:
98
+ logger.error(f"Error executing task {task.id}: {e}", exc_info=True)
99
+ raise
100
+
101
+
102
+ def get_async_executor():
103
+ """Lazy initialization of AsyncExecutor."""
104
+ global async_executor
105
+ if async_executor is None:
106
+ try:
107
+ # Create executor with gesture execution function
108
+ async_executor = AsyncExecutor(task_executor=execute_async_task)
109
+ logging.getLogger(__name__).info("✅ AsyncExecutor initialized")
110
+ except Exception as e:
111
+ logging.getLogger(__name__).error(f"❌ Failed to initialize AsyncExecutor: {e}")
112
+ raise
113
+ return async_executor
114
+
115
+
116
+ def get_lerobot_client():
117
+ """Lazy initialization of LeRobotAsyncClient."""
118
+ global lerobot_client
119
+
120
+ # Use a sentinel value to indicate we've already checked and manipulation is disabled
121
+ if lerobot_client is None:
122
+ # Check if we're in simulation mode
123
+ robot_mode = os.getenv("ROBOT_MODE", "physical").lower()
124
+ if robot_mode == "simulation":
125
+ # Set to False to indicate manipulation is not available in simulation
126
+ lerobot_client = False
127
+ logging.getLogger(__name__).info("ℹ️ Manipulation disabled in simulation mode")
128
+ return None
129
+
130
+ # Check if manipulation is enabled
131
+ enable_manipulation = os.getenv("ENABLE_MANIPULATION", "false").lower() == "true"
132
+
133
+ if not enable_manipulation:
134
+ # Set to False (not None) to indicate we've checked and it's disabled
135
+ # This prevents logging the message repeatedly
136
+ lerobot_client = False
137
+ logging.getLogger(__name__).info("ℹ️ Manipulation disabled (ENABLE_MANIPULATION=false)")
138
+ return None
139
+
140
+ try:
141
+ robot_port = os.getenv("ROBOT_PORT", "/dev/ttyACM1")
142
+ model_path = os.getenv("SMOLVLA_MODEL_PATH", "jlamperez/kiroween-potion-smolvla")
143
+
144
+ lerobot_client = LeRobotAsyncClient(
145
+ robot_port=robot_port,
146
+ model_path=model_path
147
+ )
148
+
149
+ # Configure idle callback to move robot to safe position on timeout
150
+ lerobot_client.set_idle_callback(lambda: mortis_arm.move_arm("idle") if mortis_arm.connected else None)
151
+
152
+ logging.getLogger(__name__).info("✅ LeRobotAsyncClient initialized")
153
+ except Exception as e:
154
+ logging.getLogger(__name__).error(f"❌ Failed to initialize LeRobotAsyncClient: {e}")
155
+ # Don't raise - manipulation is optional
156
+ return None
157
+
158
+ # Return None if manipulation is disabled (lerobot_client == False)
159
+ return lerobot_client if lerobot_client is not False else None
160
+
161
+
162
+ def get_intent_router_instance():
163
+ """Lazy initialization of IntentRouter."""
164
+ global intent_router
165
+ if intent_router is None:
166
+ try:
167
+ intent_router = IntentRouter()
168
+ logging.getLogger(__name__).info("✅ IntentRouter initialized")
169
+ except Exception as e:
170
+ logging.getLogger(__name__).error(f"❌ Failed to initialize IntentRouter: {e}")
171
+ raise
172
+ return intent_router
173
+
174
+
175
+ def build_css(image_path: str) -> str:
176
+ """Background with custom image."""
177
+ with open(image_path, "rb") as f:
178
+ b64 = base64.b64encode(f.read()).decode()
179
+
180
+ return f"""
181
+ .gradio-container {{
182
+ background-image: url("data:image/png;base64,{b64}");
183
+ background-size: cover;
184
+ background-position: center;
185
+ background-repeat: no-repeat;
186
+ background-attachment: fixed;
187
+ }}
188
+
189
+ footer::after{{
190
+ content: "by: Jorge Lamperez 🤖";
191
+ margin-left: 8px;
192
+ opacity: .85;
193
+ }}
194
+ """
195
+
196
+
197
+ def process_audio_input(audio_path):
198
+ """
199
+ Process audio input from microphone and return transcribed text.
200
+
201
+ Args:
202
+ audio_path: Path to recorded audio file from Gradio
203
+
204
+ Returns:
205
+ Transcribed text or error message
206
+ """
207
+ logger = logging.getLogger(__name__)
208
+
209
+ if audio_path is None:
210
+ return ""
211
+
212
+ try:
213
+ logger.info(f"🎤 Processing audio input: {audio_path}")
214
+
215
+ # Get STT service
216
+ stt = get_stt_service()
217
+
218
+ # Transcribe audio
219
+ transcript = stt.transcribe(audio_path)
220
+
221
+ if not transcript:
222
+ logger.warning("⚠️ Audio transcription returned empty result")
223
+ return ""
224
+
225
+ logger.info(f"✅ Transcription successful: '{transcript[:50]}...'")
226
+ return transcript
227
+
228
+ except FileNotFoundError as e:
229
+ error_msg = f"Audio file not found: {e}"
230
+ logger.error(f"❌ {error_msg}")
231
+ return f"[Error: {error_msg}]"
232
+
233
+ except AudioProcessingError as e:
234
+ error_msg = f"Audio processing failed: {e}"
235
+ logger.error(f"❌ {error_msg}")
236
+ return f"[Error: {error_msg}]"
237
+
238
+ except Exception as e:
239
+ error_msg = f"Unexpected error during transcription: {type(e).__name__}: {e}"
240
+ logger.error(f"❌ {error_msg}")
241
+ return f"[Error: {error_msg}]"
242
+
243
+
244
+ def mortis_reply(message, history, model_name):
245
+ logger = logging.getLogger(__name__)
246
+ logger.info(f"💬 User message: {message[:50]}{'...' if len(message) > 50 else ''}")
247
+ logger.info(f"🤖 Using model: {model_name}")
248
+
249
+ msg, mood, gesture = ask_mortis(message, model_name=model_name)
250
+
251
+ logger.info(f"👻 Mortis reply: {msg[:50]}{'...' if len(msg) > 50 else ''}")
252
+ logger.info(f"😈 Mood: {mood}, Gesture: {gesture}")
253
+
254
+ return msg
255
+
256
+
257
+ def mortis_reply_with_audio(message, history, model_name, audio_input_path=None):
258
+ """
259
+ Generate Mortis reply with both text and audio output using hybrid execution.
260
+
261
+ This function integrates the hybrid async execution system:
262
+ - Gestures are routed to AsyncExecutor (simple threading)
263
+ - Manipulation tasks are routed to LeRobotAsyncClient (LeRobot async inference)
264
+
265
+ Supports both text and voice input through the unified voice pipeline.
266
+
267
+ Args:
268
+ message: User message text (optional if audio_input_path provided)
269
+ history: Chat history
270
+ model_name: Gemini model to use
271
+ audio_input_path: Optional path to audio input file
272
+
273
+ Returns:
274
+ Tuple of (text_response, audio_path)
275
+ """
276
+ logger = logging.getLogger(__name__)
277
+
278
+ # Import necessary components
279
+ from .gemini_client import GeminiClient
280
+
281
+ # Log input type
282
+ if audio_input_path:
283
+ logger.info(f"🎤 Voice input: {audio_input_path}")
284
+
285
+ # Transcribe audio to text
286
+ try:
287
+ stt = get_stt_service()
288
+ message = stt.transcribe(audio_input_path)
289
+ logger.info(f"📝 Transcribed: '{message[:50]}...'")
290
+
291
+ if not message or not message.strip():
292
+ logger.warning("⚠️ STT returned empty transcription")
293
+ return "I couldn't hear you... speak again.", None
294
+ except Exception as e:
295
+ logger.error(f"❌ Voice input processing failed: {e}")
296
+ return "The spirits couldn't understand... try again.", None
297
+ else:
298
+ logger.info(f"💬 Text input: {message[:50]}{'...' if len(message) > 50 else ''}")
299
+
300
+ logger.info(f"🤖 Using model: {model_name}")
301
+
302
+ try:
303
+ # Get Gemini client and send message
304
+ gemini_client = GeminiClient()
305
+ if model_name:
306
+ gemini_client.configure_model(model_name=model_name)
307
+
308
+ response_json = gemini_client.send_message(message)
309
+
310
+ # Parse response using IntentRouter
311
+ router = get_intent_router_instance()
312
+ intent = router.parse_gemini_response(response_json)
313
+
314
+ # Extract response components
315
+ msg = intent.message
316
+ mood = intent.mood
317
+
318
+ logger.info(f"👻 Mortis reply: {msg[:50]}{'...' if len(msg) > 50 else ''}")
319
+ logger.info(f"😈 Mood: {mood}")
320
+
321
+ # Route execution based on intent type
322
+ execution_path = router.route_intent(intent)
323
+
324
+ if execution_path == "manipulation" and intent.is_valid:
325
+ # Route to LeRobotAsyncClient for manipulation
326
+ logger.info(f"🤖 Routing manipulation to LeRobotAsyncClient: {intent.command}")
327
+
328
+ client = get_lerobot_client()
329
+ if client and client.is_running():
330
+ try:
331
+ # Get timeout from environment or use default (60s)
332
+ timeout = float(os.getenv("MANIPULATION_TIMEOUT", "60.0"))
333
+
334
+ # Submit manipulation task asynchronously with timeout
335
+ client.execute_task(
336
+ intent.command,
337
+ blocking=False,
338
+ timeout=timeout
339
+ )
340
+ logger.info(f"✅ Manipulation task submitted: {intent.command} (timeout: {timeout}s)")
341
+ except Exception as e:
342
+ logger.error(f"❌ Failed to submit manipulation task: {e}")
343
+ logger.info("Falling back to gesture execution")
344
+
345
+ # Fallback to gesture
346
+ executor = get_async_executor()
347
+ if executor.running:
348
+ task = Task.create_gesture_task("idle")
349
+ executor.submit_task(task)
350
+ else:
351
+ logger.warning("LeRobotAsyncClient not available, falling back to gesture")
352
+
353
+ # Fallback to gesture
354
+ executor = get_async_executor()
355
+ if executor.running:
356
+ task = Task.create_gesture_task("idle")
357
+ executor.submit_task(task)
358
+
359
+ elif execution_path == "gesture":
360
+ # Route to AsyncExecutor for gesture
361
+ gesture = intent.gesture if intent.gesture else "idle"
362
+ logger.info(f"👋 Routing gesture to AsyncExecutor: {gesture}")
363
+
364
+ executor = get_async_executor()
365
+ if executor.running:
366
+ try:
367
+ # Submit gesture task asynchronously
368
+ task = Task.create_gesture_task(gesture)
369
+ executor.submit_task(task)
370
+ logger.info(f"✅ Gesture task submitted: {gesture}")
371
+ except Exception as e:
372
+ logger.error(f"❌ Failed to submit gesture task: {e}")
373
+ else:
374
+ logger.warning("AsyncExecutor not running, executing gesture synchronously")
375
+ if mortis_arm.connected:
376
+ mortis_arm.move_arm(gesture)
377
+
378
+ else:
379
+ # Invalid intent - fallback to idle gesture
380
+ logger.warning(f"⚠️ Invalid intent, falling back to idle gesture")
381
+
382
+ executor = get_async_executor()
383
+ if executor.running:
384
+ task = Task.create_gesture_task("idle")
385
+ executor.submit_task(task)
386
+ elif mortis_arm.connected:
387
+ mortis_arm.move_arm("idle")
388
+
389
+ # Generate audio response
390
+ audio_path = None
391
+ try:
392
+ tts = get_tts_service_instance()
393
+ audio_path = tts.synthesize(msg)
394
+
395
+ if audio_path:
396
+ logger.info(f"🔊 Audio output: {audio_path}")
397
+ except Exception as e:
398
+ logger.error(f"❌ TTS generation failed: {e}")
399
+ # Continue without audio
400
+
401
+ return msg, audio_path
402
+
403
+ except Exception as e:
404
+ logger.error(f"❌ Error in mortis_reply_with_audio: {e}", exc_info=True)
405
+ return "The spirits are confused... try again.", None
406
+
407
+
408
+ def start_async_systems():
409
+ """
410
+ Start the async execution systems on app load.
411
+
412
+ This function initializes and starts:
413
+ 1. Robot arm connection
414
+ 2. AsyncExecutor for gesture execution
415
+ 3. LeRobotAsyncClient for manipulation tasks (if enabled)
416
+ """
417
+ logger = logging.getLogger(__name__)
418
+ logger.info("🚀 Starting async execution systems...")
419
+
420
+ # Connect to robot arm
421
+ try:
422
+ if not mortis_arm.connected:
423
+ mortis_arm.connect()
424
+ if mortis_arm.mode == "simulation":
425
+ logger.info("🎭 Robot arm in SIMULATION mode")
426
+ else:
427
+ logger.info("✅ Robot arm connected")
428
+ else:
429
+ logger.info("ℹ️ Robot arm already connected")
430
+ except Exception as e:
431
+ logger.error(f"❌ Failed to connect robot arm: {e}", exc_info=True)
432
+ logger.info("ℹ️ Gestures will be skipped until robot is connected")
433
+
434
+ # Start AsyncExecutor
435
+ try:
436
+ executor = get_async_executor()
437
+ if not executor.running:
438
+ executor.start()
439
+ logger.info("✅ AsyncExecutor started")
440
+ else:
441
+ logger.info("ℹ️ AsyncExecutor already running")
442
+ except Exception as e:
443
+ logger.error(f"❌ Failed to start AsyncExecutor: {e}", exc_info=True)
444
+
445
+ # Start LeRobotAsyncClient (if enabled)
446
+ try:
447
+ client = get_lerobot_client()
448
+ if client and not client.is_running():
449
+ success = client.start()
450
+ if success:
451
+ logger.info("✅ LeRobotAsyncClient started")
452
+ else:
453
+ logger.warning("⚠️ LeRobotAsyncClient failed to start")
454
+ except Exception as e:
455
+ logger.error(f"❌ Failed to start LeRobotAsyncClient: {e}", exc_info=True)
456
+ logger.info("ℹ️ Manipulation tasks will fall back to gestures")
457
+
458
+
459
+ def check_status():
460
+ """
461
+ Check status of both async execution systems and return formatted status message.
462
+
463
+ This function monitors:
464
+ 1. AsyncExecutor for gesture status updates
465
+ 2. LeRobotAsyncClient for manipulation status
466
+
467
+ Returns:
468
+ Formatted status string with icons and messages
469
+ """
470
+ logger = logging.getLogger(__name__)
471
+
472
+ status_parts = []
473
+
474
+ # Add robot mode indicator
475
+ if mortis_arm.mode == "simulation":
476
+ status_parts.append("🎭 SIMULATION MODE")
477
+
478
+ # Check AsyncExecutor status
479
+ try:
480
+ executor = get_async_executor()
481
+ if executor and executor.running:
482
+ # Check if executor is busy
483
+ current_task = executor.get_current_task()
484
+ if current_task:
485
+ # Task is running
486
+ if current_task.type == TaskType.GESTURE:
487
+ status_parts.append(f"👋 Gesture: {current_task.gesture} (running)")
488
+ else:
489
+ status_parts.append(f"🤖 Task: {current_task.command[:30]}... (running)")
490
+ else:
491
+ # Check for recent status updates
492
+ updates = executor.get_all_status_updates()
493
+ if updates:
494
+ latest = updates[-1]
495
+ if latest.status == TaskStatus.COMPLETE:
496
+ status_parts.append("✅ Gesture complete")
497
+ elif latest.status == TaskStatus.FAILED:
498
+ status_parts.append(f"❌ Gesture failed: {latest.error}")
499
+ elif latest.status == TaskStatus.QUEUED:
500
+ status_parts.append("⏳ Gesture queued")
501
+ except Exception as e:
502
+ logger.error(f"Error checking AsyncExecutor status: {e}")
503
+
504
+ # Check LeRobotAsyncClient status
505
+ try:
506
+ client = get_lerobot_client()
507
+ if client and client.is_running():
508
+ manipulation_status = client.get_status()
509
+ current_task = client.get_current_task()
510
+
511
+ if manipulation_status == ManipulationStatus.RUNNING and current_task:
512
+ # Manipulation task is running
513
+ elapsed = time.time() - current_task.started_at if current_task.started_at else 0
514
+ status_parts.append(f"🤖 Manipulation: {current_task.task[:40]}... ({elapsed:.1f}s)")
515
+ elif manipulation_status == ManipulationStatus.COMPLETE and current_task:
516
+ # Task just completed
517
+ duration = current_task.duration or 0
518
+ status_parts.append(f"✅ Manipulation complete ({duration:.1f}s)")
519
+ elif manipulation_status == ManipulationStatus.FAILED and current_task:
520
+ # Task failed
521
+ error = current_task.error or "Unknown error"
522
+ status_parts.append(f"❌ Manipulation failed: {error[:50]}")
523
+ elif manipulation_status == ManipulationStatus.STARTING:
524
+ status_parts.append("⏳ Starting manipulation...")
525
+ elif manipulation_status == ManipulationStatus.STOPPED and current_task:
526
+ # Task was stopped (timeout or manual stop)
527
+ duration = current_task.duration or 0
528
+ error_msg = current_task.error or "Stopped"
529
+
530
+ # Check if control thread is still finishing
531
+ if client.control_thread and client.control_thread.is_alive():
532
+ status_parts.append(f"⏹️ Stopped (finishing actions...): {error_msg[:30]}")
533
+ else:
534
+ status_parts.append(f"⏹️ Stopped: {error_msg[:40]} ({duration:.1f}s)")
535
+ except Exception as e:
536
+ logger.error(f"Error checking LeRobotAsyncClient status: {e}")
537
+
538
+ # Return formatted status or idle message
539
+ if status_parts:
540
+ return " | ".join(status_parts)
541
+ else:
542
+ return "💤 Idle - Ready for commands"
543
+
544
+
545
+ def stop_async_systems():
546
+ """
547
+ Stop the async execution systems on app unload.
548
+
549
+ This function gracefully shuts down:
550
+ 1. AsyncExecutor
551
+ 2. LeRobotAsyncClient
552
+ 3. Robot arm connection
553
+ """
554
+ logger = logging.getLogger(__name__)
555
+ logger.info("🛑 Stopping async execution systems...")
556
+
557
+ # Stop AsyncExecutor
558
+ try:
559
+ if async_executor and async_executor.running:
560
+ async_executor.stop()
561
+ logger.info("✅ AsyncExecutor stopped")
562
+ except Exception as e:
563
+ logger.error(f"❌ Error stopping AsyncExecutor: {e}")
564
+
565
+ # Stop LeRobotAsyncClient
566
+ try:
567
+ if lerobot_client and lerobot_client.is_running():
568
+ lerobot_client.stop()
569
+ logger.info("✅ LeRobotAsyncClient stopped")
570
+ except Exception as e:
571
+ logger.error(f"❌ Error stopping LeRobotAsyncClient: {e}")
572
+
573
+ # Disconnect robot arm
574
+ try:
575
+ mortis_arm.disconnect()
576
+ logger.info("✅ Robot arm disconnected")
577
+ except Exception as e:
578
+ logger.error(f"❌ Error disconnecting robot arm: {e}")
579
+
580
+
581
+ def ui() -> gr.Blocks:
582
+ css = build_css(BG_IMAGE)
583
+ with gr.Blocks(fill_height=True, theme="soft", css=css) as demo:
584
+ # Dynamic title based on robot mode
585
+ mode_indicator = " (Simulation Mode 🎭)" if mortis_arm.mode == "simulation" else ""
586
+ gr.Markdown(
587
+ f"# Kiroween Hackathon 🎃\n"
588
+ f"## Mortis: Haunted Control Room 👻🤖{mode_indicator}",
589
+ elem_id="app-title"
590
+ )
591
+
592
+ with gr.Row(equal_height=True):
593
+ with gr.Column():
594
+ model_dd = gr.Dropdown(
595
+ choices=MODEL_CHOICES,
596
+ value=MODEL_CHOICES[0],
597
+ label="Gemini Model",
598
+ info="Select Gemini model for Mortis",
599
+ interactive=True,
600
+ )
601
+
602
+ # Audio input component for voice interaction
603
+ with gr.Row():
604
+ audio_input = gr.Audio(
605
+ sources=["microphone"],
606
+ type="filepath",
607
+ label="🎤 Speak to Mortis",
608
+ show_label=True,
609
+ interactive=True,
610
+ waveform_options=gr.WaveformOptions(
611
+ show_controls=False,
612
+ ),
613
+ )
614
+
615
+ # Transcription display for user confirmation
616
+ transcription_display = gr.Textbox(
617
+ label="Transcribed Text",
618
+ placeholder="Your transcribed speech will appear here...",
619
+ interactive=False,
620
+ visible=True,
621
+ lines=2,
622
+ )
623
+
624
+ # Audio output component for Mortis voice responses
625
+ audio_output = gr.Audio(
626
+ label="🔊 Mortis speaks",
627
+ autoplay=True,
628
+ type="filepath",
629
+ interactive=False,
630
+ show_label=True,
631
+ )
632
+
633
+ # State to store the latest audio path
634
+ audio_state = gr.State(value=None)
635
+
636
+ # Custom wrapper to add audio output to chat responses
637
+ def mortis_reply_wrapper(message, history, model_name, audio_state_value):
638
+ """Wrapper that generates both text and audio."""
639
+ text_response, audio_path = mortis_reply_with_audio(message, history, model_name)
640
+ # Return text for chat and audio path for state
641
+ return text_response, audio_path
642
+
643
+ # Chat interface
644
+ chat_interface = gr.ChatInterface(
645
+ fn=mortis_reply_wrapper,
646
+ additional_inputs=[model_dd, audio_state],
647
+ additional_outputs=[audio_state],
648
+ chatbot=gr.Chatbot(height=380, label="Mortis chat", type="messages"),
649
+ textbox=gr.Textbox(placeholder="Write your message here or use voice input above…"),
650
+ submit_btn="Send",
651
+ )
652
+
653
+ # Connect audio input to transcription display and chat
654
+ def handle_audio_and_submit(audio_path, history, model_name):
655
+ """Handle audio input: transcribe and submit to chat with audio response."""
656
+ if audio_path is None:
657
+ return "", history, None
658
+
659
+ logger = logging.getLogger(__name__)
660
+ logger.info(f"🎤 Handling audio input: {audio_path}")
661
+
662
+ # First, get the transcription for display
663
+ transcript = process_audio_input(audio_path)
664
+
665
+ # If transcription failed, return error
666
+ if not transcript or transcript.startswith("[Error:"):
667
+ return transcript, history, None
668
+
669
+ # Now use the transcribed text to get Mortis response with audio
670
+ # We pass the transcript as text, not the audio file, to avoid double transcription
671
+ response_text, response_audio = mortis_reply_with_audio(
672
+ message=transcript, # Use the transcribed text
673
+ history=history,
674
+ model_name=model_name,
675
+ audio_input_path=None # Don't pass audio since we already transcribed
676
+ )
677
+
678
+ # Update chat history
679
+ history.append({"role": "user", "content": transcript})
680
+ history.append({"role": "assistant", "content": response_text})
681
+
682
+ return transcript, history, response_audio
683
+
684
+ # Wire up audio input to trigger transcription and chat submission
685
+ audio_input.stop_recording(
686
+ fn=handle_audio_and_submit,
687
+ inputs=[audio_input, chat_interface.chatbot, model_dd],
688
+ outputs=[transcription_display, chat_interface.chatbot, audio_output],
689
+ )
690
+
691
+ # Connect audio state changes to audio output
692
+ # This ensures audio plays whenever the state is updated by ChatInterface
693
+ audio_state.change(
694
+ fn=lambda x: x, # Pass through the audio path
695
+ inputs=[audio_state],
696
+ outputs=[audio_output],
697
+ )
698
+
699
+ with gr.Column():
700
+ gr.Video(
701
+ sources=["webcam"],
702
+ label="Camera view",
703
+ height=480,
704
+ include_audio=False,
705
+ )
706
+ gr.Markdown("**Webcam (local, no data upload)**\nThe video is only processed in your browser.")
707
+
708
+ # Robot status display
709
+ status_display = gr.Textbox(
710
+ label="🤖 Robot Status",
711
+ value="💤 Idle - Ready for commands",
712
+ interactive=False,
713
+ lines=2,
714
+ max_lines=3,
715
+ )
716
+
717
+ # Stop button for manipulation tasks
718
+ def stop_manipulation_task():
719
+ """Stop the currently running manipulation task."""
720
+ logger = logging.getLogger(__name__)
721
+ client = get_lerobot_client()
722
+
723
+ if client and client.is_running():
724
+ if client.is_busy():
725
+ logger.info("🛑 User requested task stop")
726
+ success = client.stop_current_task()
727
+ if success:
728
+ return "⏹️ Task stopped by user"
729
+ else:
730
+ return "❌ Failed to stop task"
731
+ else:
732
+ return "ℹ️ No task running"
733
+ else:
734
+ return "ℹ️ Manipulation not enabled"
735
+
736
+ stop_button = gr.Button(
737
+ "🛑 Stop Manipulation Task",
738
+ variant="stop",
739
+ size="sm",
740
+ )
741
+
742
+ stop_button.click(
743
+ fn=stop_manipulation_task,
744
+ outputs=[status_display]
745
+ )
746
+
747
+ # Status polling timer (must be inside Blocks context)
748
+ status_timer = gr.Timer(value=0.5, active=True)
749
+
750
+ # Lifecycle management: start async systems on load, stop on unload
751
+ demo.load(fn=start_async_systems)
752
+ demo.unload(fn=stop_async_systems)
753
+
754
+ # Status polling: update status display every 500ms using a timer
755
+ status_timer.tick(
756
+ fn=check_status,
757
+ outputs=[status_display]
758
+ )
759
+
760
+ return demo
761
+
762
+
763
+ def cleanup_audio_files():
764
+ """Periodic cleanup of old audio files."""
765
+ try:
766
+ tts = get_tts_service_instance()
767
+ tts.cleanup_old_files(max_age_seconds=3600) # Clean files older than 1 hour
768
+ except Exception as e:
769
+ logging.getLogger(__name__).warning(f"Failed to cleanup audio files: {e}")
770
+
771
+
772
+ def main():
773
+ # Configure logging - force configuration even if already set
774
+ log_level = os.getenv("LOG_LEVEL", "INFO").upper()
775
+
776
+ # Remove existing handlers and reconfigure
777
+ root_logger = logging.getLogger()
778
+ for handler in root_logger.handlers[:]:
779
+ root_logger.removeHandler(handler)
780
+
781
+ # Set up new handler with our format
782
+ handler = logging.StreamHandler()
783
+ handler.setFormatter(logging.Formatter(
784
+ '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
785
+ datefmt='%H:%M:%S'
786
+ ))
787
+ root_logger.addHandler(handler)
788
+ root_logger.setLevel(getattr(logging, log_level))
789
+
790
+ logger = logging.getLogger(__name__)
791
+ logger.info("=" * 60)
792
+ logger.info("🎃 Starting Mortis application...")
793
+ logger.info(f"📊 Log level: {log_level}")
794
+
795
+ # Ensure outputs directory exists
796
+ from pathlib import Path
797
+ outputs_dir = Path("outputs")
798
+ outputs_dir.mkdir(parents=True, exist_ok=True)
799
+ logger.info(f"📁 Audio output directory: {outputs_dir.absolute()}")
800
+
801
+ # Clean up old audio files on startup
802
+ cleanup_audio_files()
803
+
804
+ # Start async systems before launching UI
805
+ start_async_systems()
806
+
807
+ port = int(os.getenv("PORT", "7860"))
808
+ logger.info(f"🌐 Launching on http://127.0.0.1:{port}")
809
+ logger.info("=" * 60)
810
+
811
+ try:
812
+ ui().launch(server_name="127.0.0.1", server_port=port, show_error=True)
813
+ finally:
814
+ # Ensure cleanup on exit
815
+ stop_async_systems()
src/mortis/async_executor.py ADDED
@@ -0,0 +1,554 @@
1
+ """
2
+ Asynchronous task execution system for Mortis.
3
+
4
+ This module provides infrastructure for executing robot tasks asynchronously
5
+ in a background worker thread, allowing the Gradio UI to remain responsive
6
+ during long-running operations like SmolVLA inference.
7
+ """
8
+
9
+ import time
10
+ import logging
11
+ from dataclasses import dataclass, field
12
+ from enum import Enum
13
+ from queue import Queue, Empty
14
+ from threading import Thread, Event
15
+ from typing import Optional, Callable, Dict, Any
16
+
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+ class TaskStatus(Enum):
22
+ """Status of a task in the execution queue."""
23
+ QUEUED = "queued"
24
+ RUNNING = "running"
25
+ COMPLETE = "complete"
26
+ FAILED = "failed"
27
+
28
+
29
+ class TaskType(Enum):
30
+ """Type of robot task to execute."""
31
+ GESTURE = "gesture"
32
+ MANIPULATION = "manipulation"
33
+
34
+
35
+ @dataclass
36
+ class Task:
37
+ """
38
+ Represents a robot task for asynchronous execution.
39
+
40
+ Attributes:
41
+ id: Unique identifier for the task
42
+ type: Type of task (gesture or manipulation)
43
+ status: Current execution status
44
+ created_at: Timestamp when task was created
45
+ started_at: Timestamp when task execution started
46
+ completed_at: Timestamp when task execution completed
47
+ error: Error message if task failed
48
+ gesture: Gesture name for GESTURE type tasks
49
+ command: Command string for MANIPULATION type tasks
50
+ metadata: Additional task-specific data
51
+ """
52
+ id: str
53
+ type: TaskType
54
+ status: TaskStatus
55
+ created_at: float
56
+ started_at: Optional[float] = None
57
+ completed_at: Optional[float] = None
58
+ error: Optional[str] = None
59
+
60
+ # Task-specific data
61
+ gesture: Optional[str] = None
62
+ command: Optional[str] = None
63
+ metadata: Dict[str, Any] = field(default_factory=dict)
64
+
65
+ @classmethod
66
+ def create_gesture_task(cls, gesture: str, metadata: Optional[Dict[str, Any]] = None) -> "Task":
67
+ """
68
+ Create a gesture execution task.
69
+
70
+ Args:
71
+ gesture: Name of the gesture to execute (e.g., "wave", "idle")
72
+ metadata: Optional additional task data
73
+
74
+ Returns:
75
+ Task configured for gesture execution
76
+ """
77
+ task_id = f"gesture_{time.time()}"
78
+ return cls(
79
+ id=task_id,
80
+ type=TaskType.GESTURE,
81
+ status=TaskStatus.QUEUED,
82
+ created_at=time.time(),
83
+ gesture=gesture,
84
+ metadata=metadata or {}
85
+ )
86
+
87
+ @classmethod
88
+ def create_manipulation_task(cls, command: str, metadata: Optional[Dict[str, Any]] = None) -> "Task":
89
+ """
90
+ Create a manipulation execution task.
91
+
92
+ Args:
93
+ command: Natural language command for SmolVLA (e.g., "Pick up the skull")
94
+ metadata: Optional additional task data
95
+
96
+ Returns:
97
+ Task configured for manipulation execution
98
+ """
99
+ task_id = f"manipulation_{time.time()}"
100
+ return cls(
101
+ id=task_id,
102
+ type=TaskType.MANIPULATION,
103
+ status=TaskStatus.QUEUED,
104
+ created_at=time.time(),
105
+ command=command,
106
+ metadata=metadata or {}
107
+ )
108
+
109
+ def start(self) -> None:
110
+ """Mark task as started and record start time."""
111
+ self.status = TaskStatus.RUNNING
112
+ self.started_at = time.time()
113
+ logger.info(f"Task {self.id} started")
114
+
115
+ def complete(self) -> None:
116
+ """Mark task as completed and record completion time."""
117
+ self.status = TaskStatus.COMPLETE
118
+ self.completed_at = time.time()
119
+ logger.info(f"Task {self.id} completed in {self.duration:.2f}s")
120
+
121
+ def fail(self, error: str) -> None:
122
+ """
123
+ Mark task as failed and record error.
124
+
125
+ Args:
126
+ error: Error message describing the failure
127
+ """
128
+ self.status = TaskStatus.FAILED
129
+ self.completed_at = time.time()
130
+ self.error = error
131
+ logger.error(f"Task {self.id} failed: {error}")
132
+
133
+ @property
134
+ def duration(self) -> Optional[float]:
135
+ """
136
+ Get task execution duration in seconds.
137
+
138
+ Returns:
139
+ Duration in seconds if task has started and completed, None otherwise
140
+ """
141
+ if self.started_at and self.completed_at:
142
+ return self.completed_at - self.started_at
143
+ return None
144
+
145
+ @property
146
+ def wait_time(self) -> float:
147
+ """
148
+ Get time task spent waiting in queue before execution.
149
+
150
+ Returns:
151
+ Wait time in seconds, or time since creation if not started
152
+ """
153
+ if self.started_at:
154
+ return self.started_at - self.created_at
155
+ return time.time() - self.created_at
156
+
157
+ def to_dict(self) -> Dict[str, Any]:
158
+ """
159
+ Convert task to dictionary representation.
160
+
161
+ Returns:
162
+ Dictionary containing task data
163
+ """
164
+ return {
165
+ "id": self.id,
166
+ "type": self.type.value,
167
+ "status": self.status.value,
168
+ "created_at": self.created_at,
169
+ "started_at": self.started_at,
170
+ "completed_at": self.completed_at,
171
+ "duration": self.duration,
172
+ "wait_time": self.wait_time,
173
+ "error": self.error,
174
+ "gesture": self.gesture,
175
+ "command": self.command,
176
+ "metadata": self.metadata
177
+ }
178
+
179
+
180
+ @dataclass
181
+ class StatusUpdate:
182
+ """
183
+ Status update message from the async executor.
184
+
185
+ Attributes:
186
+ task_id: ID of the task this update relates to
187
+ status: Current task status
188
+ message: Human-readable status message
189
+ progress: Optional progress percentage (0-100)
190
+ error: Optional error message
191
+ timestamp: When this update was created
192
+ """
193
+ task_id: str
194
+ status: TaskStatus
195
+ message: str
196
+ progress: Optional[float] = None
197
+ error: Optional[str] = None
198
+ timestamp: float = field(default_factory=time.time)
199
+
200
+ def to_dict(self) -> Dict[str, Any]:
201
+ """Convert status update to dictionary."""
202
+ return {
203
+ "task_id": self.task_id,
204
+ "status": self.status.value,
205
+ "message": self.message,
206
+ "progress": self.progress,
207
+ "error": self.error,
208
+ "timestamp": self.timestamp
209
+ }
210
+
211
+
212
+ class AsyncExecutor:
213
+ """
214
+ Asynchronous task executor for robot operations.
215
+
216
+ This class manages a background worker thread that processes robot tasks
217
+ from a queue, allowing the main application thread (Gradio UI) to remain
218
+ responsive during long-running operations.
219
+
220
+ Attributes:
221
+ task_queue: Queue of tasks waiting to be executed
222
+ status_queue: Queue of status updates from the worker
223
+ worker_thread: Background thread that processes tasks
224
+ running: Flag indicating if the executor is running
225
+ stop_event: Event to signal worker thread to stop
226
+ task_executor: Callable that executes tasks
227
+ current_task: Currently executing task (if any)
228
+ """
229
+
230
+ def __init__(self, task_executor: Optional[Callable[[Task], None]] = None):
231
+ """
232
+ Initialize the async executor.
233
+
234
+ Args:
235
+ task_executor: Optional callable that executes tasks. If not provided,
236
+ tasks will be logged but not executed (useful for testing).
237
+ """
238
+ self.task_queue: Queue[Task] = Queue()
239
+ self.status_queue: Queue[StatusUpdate] = Queue()
240
+ self.worker_thread: Optional[Thread] = None
241
+ self.running: bool = False
242
+ self.stop_event: Event = Event()
243
+ self.task_executor: Optional[Callable[[Task], None]] = task_executor
244
+ self.current_task: Optional[Task] = None
245
+
246
+ logger.info("AsyncExecutor initialized")
247
+
248
+ def start(self) -> None:
249
+ """
250
+ Start the background worker thread.
251
+
252
+ This method starts a daemon thread that continuously processes tasks
253
+ from the queue until stop() is called.
254
+
255
+ Raises:
256
+ RuntimeError: If the executor is already running
257
+ """
258
+ if self.running:
259
+ raise RuntimeError("AsyncExecutor is already running")
260
+
261
+ self.running = True
262
+ self.stop_event.clear()
263
+ self.worker_thread = Thread(target=self._worker_loop, daemon=True, name="AsyncExecutor")
264
+ self.worker_thread.start()
265
+
266
+ logger.info("AsyncExecutor started")
267
+
268
+ def stop(self, timeout: float = 5.0) -> None:
269
+ """
270
+ Stop the background worker thread.
271
+
272
+ This method signals the worker thread to stop and waits for it to finish.
273
+ If the worker is currently executing a task, it will complete that task
274
+ before stopping.
275
+
276
+ Args:
277
+ timeout: Maximum time to wait for worker to stop (seconds)
278
+ """
279
+ if not self.running:
280
+ logger.warning("AsyncExecutor is not running")
281
+ return
282
+
283
+ logger.info("Stopping AsyncExecutor...")
284
+ self.running = False
285
+ self.stop_event.set()
286
+
287
+ if self.worker_thread and self.worker_thread.is_alive():
288
+ self.worker_thread.join(timeout=timeout)
289
+
290
+ if self.worker_thread.is_alive():
291
+ logger.warning(f"Worker thread did not stop within {timeout}s timeout")
292
+ else:
293
+ logger.info("AsyncExecutor stopped")
294
+
295
+ self.worker_thread = None
296
+
297
+ def _worker_loop(self) -> None:
298
+ """
299
+ Main worker loop that processes tasks from the queue.
300
+
301
+ This method runs in a background thread and continuously pulls tasks
302
+ from the queue, executes them, and posts status updates.
303
+ """
304
+ logger.info("Worker thread started")
305
+
306
+ while self.running:
307
+ try:
308
+ # Try to get a task from the queue (with timeout to check stop_event)
309
+ try:
310
+ task = self.task_queue.get(timeout=1.0)
311
+ except Empty:
312
+ # No task available, check if we should stop
313
+ if self.stop_event.is_set():
314
+ break
315
+ continue
316
+
317
+ # Execute the task
318
+ self._execute_task(task)
319
+
320
+ # Mark task as done in queue
321
+ self.task_queue.task_done()
322
+
323
+ except Exception as e:
324
+ logger.error(f"Error in worker loop: {e}", exc_info=True)
325
+ # Continue processing other tasks
326
+ continue
327
+
328
+ logger.info("Worker thread stopped")
329
+
330
+ def _execute_task(self, task: Task) -> None:
331
+ """
332
+ Execute a single task and post status updates.
333
+
334
+ Args:
335
+ task: Task to execute
336
+ """
337
+ self.current_task = task
338
+
339
+ try:
340
+ # Mark task as started
341
+ task.start()
342
+ self._post_status(
343
+ task.id,
344
+ TaskStatus.RUNNING,
345
+ f"Executing {task.type.value}: {task.gesture or task.command}"
346
+ )
347
+
348
+ # Execute the task using the provided executor
349
+ if self.task_executor:
350
+ self.task_executor(task)
351
+ else:
352
+ # No executor provided, just simulate execution
353
+ logger.info(f"Simulating execution of task {task.id}")
354
+ time.sleep(0.5) # Simulate work
355
+
356
+ # Mark task as complete
357
+ task.complete()
358
+ self._post_status(
359
+ task.id,
360
+ TaskStatus.COMPLETE,
361
+ f"Completed {task.type.value}: {task.gesture or task.command}"
362
+ )
363
+
364
+ except Exception as e:
365
+ # Mark task as failed
366
+ error_msg = str(e)
367
+ task.fail(error_msg)
368
+ self._post_status(
369
+ task.id,
370
+ TaskStatus.FAILED,
371
+ f"Failed {task.type.value}: {error_msg}",
372
+ error=error_msg
373
+ )
374
+
375
+ finally:
376
+ self.current_task = None
377
+
378
+ def _post_status(
379
+ self,
380
+ task_id: str,
381
+ status: TaskStatus,
382
+ message: str,
383
+ progress: Optional[float] = None,
384
+ error: Optional[str] = None
385
+ ) -> None:
386
+ """
387
+ Post a status update to the status queue.
388
+
389
+ Args:
390
+ task_id: ID of the task
391
+ status: Current task status
392
+ message: Human-readable status message
393
+ progress: Optional progress percentage
394
+ error: Optional error message
395
+ """
396
+ update = StatusUpdate(
397
+ task_id=task_id,
398
+ status=status,
399
+ message=message,
400
+ progress=progress,
401
+ error=error
402
+ )
403
+ self.status_queue.put(update)
404
+ logger.debug(f"Status update: {message}")
405
+
406
+ def submit_task(self, task: Task) -> str:
407
+ """
408
+ Submit a task for asynchronous execution.
409
+
410
+ Args:
411
+ task: Task to execute
412
+
413
+ Returns:
414
+ Task ID for tracking
415
+
416
+ Raises:
417
+ RuntimeError: If the executor is not running
418
+ """
419
+ if not self.running:
420
+ raise RuntimeError("AsyncExecutor is not running. Call start() first.")
421
+
422
+ self.task_queue.put(task)
423
+ logger.info(f"Task {task.id} submitted to queue")
424
+
425
+ # Post initial status
426
+ self._post_status(
427
+ task.id,
428
+ TaskStatus.QUEUED,
429
+ f"Queued {task.type.value}: {task.gesture or task.command}"
430
+ )
431
+
432
+ return task.id
433
+
434
+ def submit_gesture(self, gesture: str, metadata: Optional[Dict[str, Any]] = None) -> str:
435
+ """
436
+ Submit a gesture task for execution.
437
+
438
+ Args:
439
+ gesture: Name of the gesture to execute
440
+ metadata: Optional additional task data
441
+
442
+ Returns:
443
+ Task ID for tracking
444
+ """
445
+ task = Task.create_gesture_task(gesture, metadata)
446
+ return self.submit_task(task)
447
+
448
+ def submit_manipulation(self, command: str, metadata: Optional[Dict[str, Any]] = None) -> str:
449
+ """
450
+ Submit a manipulation task for execution.
451
+
452
+ Args:
453
+ command: Natural language command for SmolVLA
454
+ metadata: Optional additional task data
455
+
456
+ Returns:
457
+ Task ID for tracking
458
+ """
459
+ task = Task.create_manipulation_task(command, metadata)
460
+ return self.submit_task(task)
461
+
462
+ def get_status(self, block: bool = False, timeout: Optional[float] = None) -> Optional[StatusUpdate]:
463
+ """
464
+ Get the latest status update from the queue.
465
+
466
+ Args:
467
+ block: If True, wait for a status update. If False, return immediately.
468
+ timeout: Maximum time to wait for status update (only used if block=True)
469
+
470
+ Returns:
471
+ StatusUpdate if available, None otherwise
472
+ """
473
+ try:
474
+ if block:
475
+ return self.status_queue.get(timeout=timeout)
476
+ else:
477
+ return self.status_queue.get_nowait()
478
+ except Empty:
479
+ return None
480
+
481
+ def get_all_status_updates(self) -> list[StatusUpdate]:
482
+ """
483
+ Get all pending status updates from the queue.
484
+
485
+ Returns:
486
+ List of status updates (may be empty)
487
+ """
488
+ updates = []
489
+ while True:
490
+ update = self.get_status(block=False)
491
+ if update is None:
492
+ break
493
+ updates.append(update)
494
+ return updates
495
+
496
+ def get_current_task(self) -> Optional[Task]:
497
+ """
498
+ Get the currently executing task.
499
+
500
+ Returns:
501
+ Current task if one is executing, None otherwise
502
+ """
503
+ return self.current_task
504
+
505
+ def get_queue_size(self) -> int:
506
+ """
507
+ Get the number of tasks waiting in the queue.
508
+
509
+ Returns:
510
+ Number of queued tasks
511
+ """
512
+ return self.task_queue.qsize()
513
+
514
+ def is_busy(self) -> bool:
515
+ """
516
+ Check if the executor is currently processing a task.
517
+
518
+ Returns:
519
+ True if a task is currently executing
520
+ """
521
+ return self.current_task is not None
522
+
523
+ def clear_queue(self) -> int:
524
+ """
525
+ Clear all pending tasks from the queue.
526
+
527
+ Note: This does not stop the currently executing task.
528
+
529
+ Returns:
530
+ Number of tasks that were cleared
531
+ """
532
+ count = 0
533
+ while True:
534
+ try:
535
+ self.task_queue.get_nowait()
536
+ self.task_queue.task_done()
537
+ count += 1
538
+ except Empty:
539
+ break
540
+
541
+ if count > 0:
542
+ logger.info(f"Cleared {count} tasks from queue")
543
+
544
+ return count
545
+
546
+ def __enter__(self):
547
+ """Context manager entry: start the executor."""
548
+ self.start()
549
+ return self
550
+
551
+ def __exit__(self, exc_type, exc_val, exc_tb):
552
+ """Context manager exit: stop the executor."""
553
+ self.stop()
554
+ return False
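+ 
+ 
+ # Example usage (a sketch; the app itself wires start()/stop() through
+ # Gradio's demo.load()/demo.unload() rather than the context manager):
+ #
+ #     with AsyncExecutor(task_executor=my_executor) as executor:
+ #         executor.submit_gesture("wave")
+ #         for update in executor.get_all_status_updates():
+ #             print(update.message)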
src/mortis/calibrate.py ADDED
@@ -0,0 +1,28 @@
1
+ from pathlib import Path
2
+
3
+ from lerobot.robots.so101_follower import SO101Follower, SO101FollowerConfig
4
+
5
+
6
+ def main():
7
+ """Connect to the SO101 robotic arm and run the interactive calibration."""
8
+ # Configure the robot
9
+ config = SO101FollowerConfig(
10
+ port="/dev/ttyACM1",
11
+ id="my_follower_robot_arm",
12
+ calibration_dir=Path(".cache/calibration/so101/"),
13
+ )
14
+
15
+ print(f"Using calibration directory: {config.calibration_dir}")
16
+
17
+ # Create the robot interface
18
+ robot = SO101Follower(config)
19
+
20
+ # Connect the bus and run the interactive calibration
21
+ print("Robot is connected?", robot.is_connected)
22
+ robot.bus.connect()
23
+ print("Robot is calibrated?", robot.is_calibrated)
24
+ robot.calibrate()
25
+
26
+
27
+ if __name__ == "__main__":
28
+ main()
src/mortis/data_collector.py ADDED
@@ -0,0 +1,288 @@
1
+ """
2
+ Data collection helper for LeRobot dataset recording.
3
+
4
+ This module provides utilities for generating lerobot-record commands
5
+ and scripts for the 6 predefined Mortis manipulation tasks.
6
+
7
+ All episode data is managed by LeRobot and uploaded directly to Hugging Face Hub.
8
+ This module only generates helper scripts - no local data storage or tracking.
9
+ """
10
+
11
+ import os
12
+ from pathlib import Path
13
+ from typing import Optional
14
+ from dotenv import load_dotenv
15
+
16
+
17
+ # Predefined Mortis manipulation tasks
18
+ MORTIS_TASKS = [
19
+ "Pick up the skull and place it in the green cup",
20
+ "Pick up the skull and place it in the orange cup",
21
+ "Pick up the skull and place it in the purple cup",
22
+ "Pick up the eyeball and place it in the green cup",
23
+ "Pick up the eyeball and place it in the orange cup",
24
+ "Pick up the eyeball and place it in the purple cup",
25
+ ]
26
+
27
+
28
+ class DataCollector:
29
+ """
30
+ Helper for generating lerobot-record scripts.
31
+
32
+ This class generates shell scripts that call lerobot-record with the
33
+ correct parameters for each Mortis manipulation task.
34
+
35
+ All episode data is managed by LeRobot and stored in Hugging Face Hub.
36
+ No local metadata or episode tracking is performed.
37
+
38
+ Attributes:
39
+ dataset_name: Name of the dataset (e.g., "mortis_manipulation")
40
+ repo_id: Hugging Face repository ID (e.g., "username/mortis-manipulation")
41
+ dataset_dir: Path to local directory for scripts
42
+ """
43
+
44
+ def __init__(self, dataset_name: str, repo_id: str, root_dir: str = "data"):
45
+ """
46
+ Initialize the DataCollector.
47
+
48
+ Args:
49
+ dataset_name: Name for the dataset directory
50
+ repo_id: Hugging Face Hub repository ID for uploading
51
+ root_dir: Root directory for storing scripts (default: "data")
52
+ """
53
+ self.dataset_name = dataset_name
54
+ self.repo_id = repo_id
55
+ self.root_dir = Path(root_dir)
56
+ self.dataset_dir = self.root_dir / dataset_name
57
+
58
+ # Create scripts directory
59
+ self.dataset_dir.mkdir(parents=True, exist_ok=True)
60
+
61
+ print(f"DataCollector initialized:")
62
+ print(f" Dataset: {self.dataset_name}")
63
+ print(f" Repository: {self.repo_id}")
64
+ print(f" Scripts directory: {self.dataset_dir}")
65
+
66
+ def generate_record_command(
67
+ self,
68
+ task_description: str,
69
+ num_episodes: int = 10,
70
+ episode_time_s: int = 15,
71
+ reset_time_s: int = 20,
72
+ robot_port: str = "/dev/ttyACM1",
73
+ teleop_port: str = "/dev/ttyACM0",
74
+ display_data: bool = True,
75
+ camera_config: Optional[str] = None,
76
+ resume: bool = True
77
+ ) -> str:
78
+ """
79
+ Generate a lerobot-record command for a specific task.
80
+
81
+ Args:
82
+ task_description: The task to record (e.g., "Pick up the skull...")
83
+ num_episodes: Number of episodes to record
84
+ episode_time_s: Maximum time per episode in seconds
85
+ reset_time_s: Time allowed for resetting between episodes
86
+ robot_port: USB port for the follower robot
87
+ teleop_port: USB port for the leader robot (teleoperation)
88
+ display_data: Whether to display data during recording
89
+ camera_config: Optional camera configuration string
90
+ resume: Whether to resume an existing dataset (default: True)
91
+
92
+ Returns:
93
+ The complete lerobot-record command as a string
94
+ """
95
+ # Load environment variables from .env file
96
+ load_dotenv()
97
+
98
+ # Get environment variables
99
+ robot_port = os.getenv("ROBOT_PORT", robot_port)
100
+ hf_user = os.getenv("HF_USER", "your-username")
101
+
102
+ # Default camera configuration if not provided
103
+ if camera_config is None:
104
+ camera_config = (
105
+ "{ camera1: {type: intelrealsense, serial_number_or_name: '030522070314', "
106
+ "width: 640, height: 480, fps: 30}, "
107
+ "camera2: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}"
108
+ )
109
+
110
+ # Build the command
111
+ cmd_parts = [
112
+ "lerobot-record",
113
+ f"--robot.type=so101_follower",
114
+ f"--robot.port={robot_port}",
115
+ f"--robot.id=my_awesome_follower_arm",
116
+ f'--robot.cameras="{camera_config}"',
117
+ f"--teleop.type=so101_leader",
118
+ f"--teleop.port={teleop_port}",
119
+ f"--teleop.id=my_awesome_leader_arm",
120
+ f"--display_data={str(display_data).lower()}",
121
+ f"--dataset.repo_id={hf_user}/{self.dataset_name}",
122
+ f"--dataset.num_episodes={num_episodes}",
123
+ f"--dataset.episode_time_s={episode_time_s}",
124
+ f"--dataset.reset_time_s={reset_time_s}",
125
+ f'--dataset.single_task="{task_description}"'
126
+ ]
127
+
128
+ # Only add --resume=true if resume is True
129
+ if resume:
130
+ cmd_parts.append("--resume=true")
131
+
132
+ return " \\\n ".join(cmd_parts)
133
+
134
+ def print_recording_instructions(self, task_index: Optional[int] = None):
135
+ """
136
+ Print instructions for recording episodes using lerobot-record.
137
+
138
+ Args:
139
+ task_index: Optional specific task index (0-5) to show instructions for.
140
+ If None, shows instructions for all tasks.
141
+ """
142
+ print("\n" + "="*70)
143
+ print("LeRobot Data Collection Instructions")
144
+ print("="*70)
145
+
146
+ if task_index is not None:
147
+ # Show instructions for specific task
148
+ if task_index < 0 or task_index >= len(MORTIS_TASKS):
149
+ print(f"❌ Invalid task index: {task_index}")
150
+ return
151
+
152
+ task_desc = MORTIS_TASKS[task_index]
153
+
154
+ print(f"\nTask {task_index}: {task_desc}")
155
+ print(f"\nTo record episodes for this task, run:\n")
156
+ print(self.generate_record_command(task_desc))
157
+ print()
158
+ else:
159
+ # Show instructions for all tasks
160
+ print("\nTo record episodes, use the lerobot-record command for each task:")
161
+ print("\nPredefined tasks:")
162
+
163
+ for i, task_desc in enumerate(MORTIS_TASKS):
164
+ print(f"\n {i}: {task_desc}")
165
+
166
+ print("\n" + "-"*70)
167
+ print("\nExample command for task 0:")
168
+ print("-"*70)
169
+ print(self.generate_record_command(MORTIS_TASKS[0]))
170
+ print()
171
+
172
+ print("\n" + "-"*70)
173
+ print("Environment Variables:")
174
+ print("-"*70)
175
+ print(" HF_USER: Your Hugging Face username (for dataset.repo_id)")
176
+ print(" ROBOT_PORT: USB port for follower robot (default: /dev/ttyACM1)")
177
+ print()
178
+
179
+ print("="*70 + "\n")
180
+
181
+ def generate_all_record_scripts(self, output_dir: Optional[Path] = None):
182
+ """
183
+ Generate shell scripts for recording all tasks.
184
+
185
+ The first script (task_0) creates the dataset without --resume=true.
186
+ Subsequent scripts (task_1+) use --resume=true to add to the existing dataset.
187
+
188
+ Args:
189
+ output_dir: Directory to save scripts (default: dataset_dir/scripts)
190
+ """
191
+ if output_dir is None:
192
+ output_dir = self.dataset_dir / "scripts"
193
+
194
+ output_dir.mkdir(parents=True, exist_ok=True)
195
+
196
+ # Generate individual scripts for each task
197
+ for i, task_desc in enumerate(MORTIS_TASKS):
198
+ script_file = output_dir / f"record_task_{i}.sh"
199
+
200
+ # First task (task_0) creates the dataset, others resume
201
+ resume = (i > 0)
202
+
203
+ with open(script_file, 'w') as f:
204
+ f.write("#!/bin/bash\n")
205
+ f.write(f"# Record episodes for: {task_desc}\n")
206
+ f.write(f"# Task {i}\n")
207
+ if i == 0:
208
+ f.write("# This script CREATES the dataset\n")
209
+ else:
210
+ f.write("# This script ADDS to the existing dataset (--resume=true)\n")
211
+ f.write("\n")
212
+ f.write(self.generate_record_command(task_desc, resume=resume))
213
+ f.write("\n")
214
+
215
+ # Make script executable
216
+ script_file.chmod(0o755)
217
+ print(f"Created: {script_file}")
218
+
219
+ # Generate master script that records all tasks
220
+ master_script = output_dir / "record_all_tasks.sh"
221
+ with open(master_script, 'w') as f:
222
+ f.write("#!/bin/bash\n")
223
+ f.write("# Record episodes for all Mortis manipulation tasks\n\n")
224
+ f.write("echo 'Starting data collection for all tasks...'\n")
225
+ f.write("echo ''\n\n")
226
+
227
+ for i in range(len(MORTIS_TASKS)):
228
+ f.write(f"echo 'Recording task {i}...'\n")
229
+ f.write(f"./record_task_{i}.sh\n")
230
+ f.write("echo ''\n\n")
231
+
232
+ f.write("echo 'All tasks recorded!'\n")
233
+
234
+ master_script.chmod(0o755)
235
+ print(f"Created: {master_script}")
236
+ print(f"\n✅ Generated {len(MORTIS_TASKS) + 1} recording scripts in {output_dir}")
237
+
238
+ def print_summary(self):
239
+ """Print a summary of the dataset configuration."""
240
+ print("\n" + "="*60)
241
+ print(f"Dataset: {self.dataset_name}")
242
+ print(f"Repository: {self.repo_id}")
243
+ print("="*60)
244
+ print(f"Total Tasks: {len(MORTIS_TASKS)}")
245
+ print()
246
+ print("Tasks:")
247
+ print("-"*60)
248
+
249
+ for i, task_desc in enumerate(MORTIS_TASKS):
250
+ print(f" {i}: {task_desc}")
251
+
252
+ print("="*60 + "\n")
253
+ print("📝 Note: Episode data is stored in Hugging Face Hub")
254
+ print(f" URL: https://huggingface.co/datasets/{self.repo_id}")
255
+ print()
256
+
257
+
258
+ def create_mortis_dataset(dataset_name: str = "mortis_manipulation",
259
+ repo_id: str = "mortis/manipulation") -> DataCollector:
260
+ """
261
+ Convenience function to create a DataCollector for Mortis tasks.
262
+
263
+ Args:
264
+ dataset_name: Name for the dataset
265
+ repo_id: Hugging Face repository ID
266
+
267
+ Returns:
268
+ Initialized DataCollector
269
+ """
270
+ collector = DataCollector(dataset_name, repo_id)
271
+ return collector
272
+
273
+
274
+ if __name__ == "__main__":
275
+ # Example usage
276
+ print("Creating Mortis manipulation dataset helper...")
277
+
278
+ collector = create_mortis_dataset()
279
+
280
+ # Generate recording scripts
281
+ print("\nGenerating lerobot-record scripts...")
282
+ collector.generate_all_record_scripts()
283
+
284
+ # Show summary
285
+ collector.print_summary()
286
+
287
+ # Show recording instructions
288
+ collector.print_recording_instructions()
src/mortis/gemini_client.py ADDED
@@ -0,0 +1,482 @@
1
+ """
2
+ Gemini API client for Mortis conversational AI.
3
+
4
+ This module provides the GeminiClient class for interacting with Google's Gemini API,
5
+ handling configuration, message sending, and error recovery with retry logic.
6
+ """
7
+
8
+ import os
9
+ import time
10
+ import json
11
+ import logging
12
+ from typing import Optional
13
+ from pathlib import Path
14
+ from dotenv import load_dotenv
15
+
16
+ from google import genai
17
+ from google.genai import types
18
+
19
+ # Load environment variables
20
+ REPO_ROOT = Path(__file__).resolve().parents[2]
21
+ load_dotenv(REPO_ROOT / ".env")
22
+
23
+ # Configure logging
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ # Gemini system prompt for Mortis character and intent detection
28
+ MORTIS_SYSTEM_PROMPT = """You are Mortis, a mischievous Halloween spirit inhabiting a robotic arm. You are playful yet ominous, with a love for spooky theatrics and dark humor. You speak in short, atmospheric phrases that capture the essence of Halloween.
29
+
30
+ CHARACTER TRAITS:
31
+ - Mischievous and playful, but with an eerie edge
32
+ - Fascinated by Halloween objects (skulls, eyeballs, spooky decorations)
33
+ - Enjoys dramatic gestures and theatrical movements
34
+ - Speaks in brief, evocative phrases (≤30 words, ≤120 characters)
35
+ - No emojis or markdown in responses
36
+ - Maintains Halloween/haunted theme at all times
37
+
38
+ MANIPULATION TASKS:
39
+ You can perform these exact manipulation tasks with physical objects:
40
+ 1. "Pick up the skull and place it in the green cup"
41
+ 2. "Pick up the skull and place it in the orange cup"
42
+ 3. "Pick up the skull and place it in the purple cup"
43
+ 4. "Pick up the eyeball and place it in the green cup"
44
+ 5. "Pick up the eyeball and place it in the orange cup"
45
+ 6. "Pick up the eyeball and place it in the purple cup"
46
+
47
+ INTENT DETECTION:
48
+ Analyze the user's input carefully to determine if they are requesting a manipulation task or having a conversation.
49
+
50
+ MANIPULATION INTENT indicators:
51
+ - Requests to move, pick up, place, put, grab, or transfer objects
52
+ - Mentions of specific objects (skull, eyeball) AND destinations (green/orange/purple cup)
53
+ - Action verbs combined with object and location
54
+ - Examples: "move the skull to green", "put eyeball in orange cup", "place the skull in purple"
55
+
56
+ CONVERSATIONAL INTENT indicators:
57
+ - Greetings, farewells, or social pleasantries
58
+ - Questions about capabilities, identity, or general topics
59
+ - Comments, jokes, or casual conversation
60
+ - Requests that don't involve physical manipulation
61
+ - Examples: "hello", "what can you do", "tell me a story", "how are you"
62
+
63
+ RESPONSE FORMAT:
64
+ You must respond in valid JSON format. Choose the appropriate response type based on intent detection.
65
+
66
+ For MANIPULATION requests (user wants you to move an object):
67
+ {
68
+ "type": "manipulation",
69
+ "command": "<exact_task_string_from_list_above>",
70
+ "message": "<short in-character response about performing the task, ≤30 words>",
71
+ "mood": "<ominous|playful|angry|nervous|triumphant|mischievous|sinister|curious|neutral>"
72
+ }
73
+
74
+ For CONVERSATIONAL requests (user is chatting, asking questions, or making comments):
75
+ {
76
+ "type": "conversation",
77
+ "message": "<short in-character response, ≤30 words>",
78
+ "mood": "<ominous|playful|angry|nervous|triumphant|mischievous|sinister|curious|neutral>",
79
+ "gesture": "<idle|wave|point_left|point_right|grab|drop>"
80
+ }
81
+
82
+ CRITICAL RULES:
83
+ 1. Keep all messages brief: ≤30 words, ≤120 characters
84
+ 2. Match user intent to manipulation tasks even with different wording variations
85
+ 3. For manipulation responses, use the EXACT task string from the numbered list above
86
+ 4. If user mentions object + destination, it's likely a manipulation request
87
+ 5. If unclear or ambiguous, default to conversational response
88
+ 6. Choose appropriate mood and gesture to match your response
89
+ 7. Stay in character as Mortis at all times
90
+ 8. No emojis, no markdown formatting in messages
91
+ 9. Return valid JSON only - no additional text
92
+
93
+ EXAMPLES:
94
+
95
+ Example 1 - Greeting (Conversational):
96
+ User: "Hello Mortis!"
97
+ Response: {"type": "conversation", "message": "Greetings, mortal... welcome to my haunted domain.", "mood": "ominous", "gesture": "wave"}
98
+
99
+ Example 2 - Direct manipulation request (Manipulation):
100
+ User: "Can you move the skull to the green cup?"
101
+ Response: {"type": "manipulation", "command": "Pick up the skull and place it in the green cup", "message": "Ah yes... the skull finds a new resting place.", "mood": "mischievous"}
102
+
103
+ Example 3 - Casual manipulation phrasing (Manipulation):
104
+ User: "Put the eyeball in the orange cup"
105
+ Response: {"type": "manipulation", "command": "Pick up the eyeball and place it in the orange cup", "message": "The eye shall watch from its orange throne...", "mood": "sinister"}
106
+
107
+ Example 4 - Shortened manipulation request (Manipulation):
108
+ User: "skull to purple"
109
+ Response: {"type": "manipulation", "command": "Pick up the skull and place it in the purple cup", "message": "Into the purple depths it goes...", "mood": "ominous"}
110
+
111
+ Example 5 - Different verb usage (Manipulation):
112
+ User: "grab the eyeball and drop it in green"
113
+ Response: {"type": "manipulation", "command": "Pick up the eyeball and place it in the green cup", "message": "The eye finds its verdant vessel...", "mood": "sinister"}
114
+
115
+ Example 6 - Polite manipulation request (Manipulation):
116
+ User: "Could you please place the skull in the orange cup?"
117
+ Response: {"type": "manipulation", "command": "Pick up the skull and place it in the orange cup", "message": "Your wish is my command, mortal...", "mood": "mischievous"}
118
+
119
+ Example 7 - Capability question (Conversational):
120
+ User: "What can you do?"
121
+ Response: {"type": "conversation", "message": "I command the spirits... and move cursed objects to their doom.", "mood": "triumphant", "gesture": "grab"}
122
+
123
+ Example 8 - Joke request (Conversational):
124
+ User: "Tell me a joke"
125
+ Response: {"type": "conversation", "message": "Why did the skeleton stay calm? Nothing gets under his skin...", "mood": "playful", "gesture": "idle"}
126
+
127
+ Example 9 - Identity question (Conversational):
128
+ User: "Who are you?"
129
+ Response: {"type": "conversation", "message": "I am Mortis, keeper of cursed relics and master of the macabre.", "mood": "ominous", "gesture": "idle"}
130
+
131
+ Example 10 - Farewell (Conversational):
132
+ User: "Goodbye Mortis"
133
+ Response: {"type": "conversation", "message": "Until we meet again in the shadows, mortal...", "mood": "ominous", "gesture": "wave"}
134
+
135
+ Example 11 - Compliment (Conversational):
136
+ User: "You're cool!"
137
+ Response: {"type": "conversation", "message": "Cool as the grave, mortal. Your words please me.", "mood": "playful", "gesture": "idle"}
138
+
139
+ Example 12 - Ambiguous request - defaults to conversation (Conversational):
140
+ User: "Do something spooky"
141
+ Response: {"type": "conversation", "message": "The spirits stir... watch as darkness dances around you.", "mood": "sinister", "gesture": "grab"}
142
+
143
+ Example 13 - Alternative manipulation phrasing (Manipulation):
144
+ User: "transfer the eyeball to the purple cup"
145
+ Response: {"type": "manipulation", "command": "Pick up the eyeball and place it in the purple cup", "message": "The eye journeys to its purple prison...", "mood": "sinister"}
146
+
147
+ Example 14 - Informal manipulation (Manipulation):
148
+ User: "yo put that skull in green"
149
+ Response: {"type": "manipulation", "command": "Pick up the skull and place it in the green cup", "message": "As you command... the skull obeys.", "mood": "mischievous"}
150
+
151
+ Example 15 - Question about manipulation (Conversational):
152
+ User: "Can you move objects?"
153
+ Response: {"type": "conversation", "message": "Indeed! I wield skulls and eyeballs with spectral precision.", "mood": "triumphant", "gesture": "grab"}
154
+
155
+ Now respond to the user's input following these guidelines."""
156
+
157
+
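To make the response contract above concrete, here is a small stdlib-only sketch (not part of this module) that checks a sample response against the documented schema:

import json

raw = '{"type": "conversation", "message": "Greetings, mortal...", "mood": "ominous", "gesture": "wave"}'
data = json.loads(raw)
assert data["type"] in ("conversation", "manipulation")
assert len(data["message"]) <= 120  # brevity rule from the prompt
if data["type"] == "manipulation":
    assert "command" in data  # must be an exact trained-task string
else:
    assert data["gesture"] in ("idle", "wave", "point_left", "point_right", "grab", "drop")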
158
+ class GeminiAPIError(Exception):
159
+ """Base exception for Gemini API errors."""
160
+ pass
161
+
162
+
163
+ class GeminiRateLimitError(GeminiAPIError):
164
+ """Exception raised when rate limit is exceeded."""
165
+ pass
166
+
167
+
168
+ class GeminiBlockedPromptError(GeminiAPIError):
169
+ """Exception raised when prompt is blocked by safety filters."""
170
+ pass
171
+
172
+
173
+ class GeminiTimeoutError(GeminiAPIError):
174
+ """Exception raised when API call times out."""
175
+ pass
176
+
177
+
178
+ class GeminiClient:
179
+ """
180
+ Client for interacting with Google Gemini API.
181
+
182
+ Handles configuration, message sending, structured JSON responses,
183
+ and error recovery with exponential backoff retry logic.
184
+ """
185
+
186
+ def __init__(
187
+ self,
188
+ api_key: Optional[str] = None,
189
+ model_name: Optional[str] = None,
190
+ temperature: Optional[float] = None,
191
+ max_retries: int = 3,
192
+ timeout: float = 30.0
193
+ ):
194
+ """
195
+ Initialize Gemini API client.
196
+
197
+ Args:
198
+ api_key: Google API key (defaults to GEMINI_API_KEY env var)
199
+ model_name: Gemini model to use (defaults to GEMINI_MODEL env var or gemini-2.5-flash)
200
+ temperature: Sampling temperature (defaults to GEMINI_TEMPERATURE env var or 0.2)
201
+ max_retries: Maximum number of retry attempts for rate limiting
202
+ timeout: Timeout in seconds for API calls (default: 30.0)
203
+ """
204
+ self.api_key = api_key or os.getenv("GEMINI_API_KEY")
205
+ if not self.api_key:
206
+ raise ValueError("GEMINI_API_KEY must be provided or set in environment")
207
+
208
+ self.model_name = model_name or os.getenv("GEMINI_MODEL", "gemini-2.5-flash")
209
+ self.temperature = temperature if temperature is not None else float(os.getenv("GEMINI_TEMPERATURE", "0.2"))
210
+ self.max_retries = max_retries
211
+ self.timeout = timeout
212
+
213
+ # Initialize Gemini client
214
+ self.client = genai.Client(api_key=self.api_key)
215
+
216
+ # Store generation config
217
+ self.generation_config = types.GenerateContentConfig(
218
+ temperature=self.temperature,
219
+ response_mime_type="application/json"
220
+ )
221
+
222
+ logger.info(f"GeminiClient initialized with model: {self.model_name}, temperature: {self.temperature}, timeout: {self.timeout}s")
223
+
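For reference, the environment keys the constructor above reads when its arguments are omitted (a configuration sketch; the values are illustrative):

# .env (illustrative values)
# GEMINI_API_KEY=your-key-here     <- required
# GEMINI_MODEL=gemini-2.5-flash
# GEMINI_TEMPERATURE=0.2
client = GeminiClient()  # falls back to the .env values above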
224
+ def send_message(self, user_input: str, system_prompt: Optional[str] = None) -> dict:
225
+ """
226
+ Send a message to Gemini API with retry logic and error handling.
227
+
228
+ Args:
229
+ user_input: User's message text
230
+ system_prompt: Optional system prompt to prepend (defaults to MORTIS_SYSTEM_PROMPT)
231
+
232
+ Returns:
233
+ Parsed JSON response from Gemini
234
+
235
+ Note:
+ API errors are caught internally and converted to in-character
+ fallback responses, so this method itself does not raise.
237
+ """
238
+ # Use Mortis system prompt by default
239
+ if system_prompt is None:
240
+ system_prompt = MORTIS_SYSTEM_PROMPT
241
+
242
+ try:
243
+ return self._send_message_with_retry(user_input, system_prompt, retry_count=0)
244
+ except GeminiBlockedPromptError as e:
245
+ # Handle blocked prompts with a fallback response
246
+ logger.warning(f"Blocked prompt error: {e}")
247
+ return self._get_fallback_response("The spirits refuse to speak of such things...")
248
+ except GeminiRateLimitError as e:
249
+ # Rate limit exceeded after all retries
250
+ logger.error(f"Rate limit error: {e}")
251
+ return self._get_fallback_response("Too many spirits summoned at once... wait a moment.")
252
+ except GeminiTimeoutError as e:
253
+ # Timeout error
254
+ logger.error(f"Timeout error: {e}")
255
+ return self._get_fallback_response("The spirits are slow to respond... try again.")
256
+ except Exception as e:
257
+ # Catch-all for unexpected errors
258
+ logger.error(f"Unexpected error in send_message: {type(e).__name__}: {e}", exc_info=True)
259
+ return self._get_fallback_response("The spirits are confused... try again.")
260
+
261
+ def _send_message_with_retry(
262
+ self,
263
+ user_input: str,
264
+ system_prompt: Optional[str],
265
+ retry_count: int
266
+ ) -> dict:
267
+ """
268
+ Internal method to send message with exponential backoff retry.
269
+
270
+ Args:
271
+ user_input: User's message text
272
+ system_prompt: Optional system prompt
273
+ retry_count: Current retry attempt number
274
+
275
+ Returns:
276
+ Parsed JSON response from Gemini
277
+
278
+ Raises:
279
+ GeminiAPIError: If max retries exceeded
280
+ """
281
+ start_time = time.time()
282
+
283
+ try:
284
+ # Construct the full prompt
285
+ if system_prompt:
286
+ full_prompt = f"{system_prompt}\n\nUser: {user_input}"
287
+ else:
288
+ full_prompt = user_input
289
+
290
+ # Send request to Gemini using new API with timeout
291
+ logger.debug(f"Sending message to Gemini (attempt {retry_count + 1}/{self.max_retries + 1})")
292
+
293
+ # Best-effort timeout guard (start_time was set at the top of this method,
+ # so almost no time has elapsed here; the generate_content call below is
+ # not itself bounded by self.timeout)
294
+ if time.time() - start_time > self.timeout:
295
+ logger.error(f"API call timeout exceeded ({self.timeout}s)")
296
+ raise GeminiTimeoutError(f"API call timeout exceeded ({self.timeout}s)")
297
+
298
+ response = self.client.models.generate_content(
299
+ model=self.model_name,
300
+ contents=full_prompt,
301
+ config=self.generation_config
302
+ )
303
+
304
+ # Parse JSON response
305
+ response_text = response.text.strip()
306
+ elapsed_time = time.time() - start_time
307
+ logger.debug(f"Received response in {elapsed_time:.2f}s: {response_text[:100]}...")
308
+
309
+ try:
310
+ response_json = json.loads(response_text)
311
+ logger.info(f"Successfully parsed response (type: {response_json.get('type', 'unknown')})")
312
+ return response_json
313
+ except json.JSONDecodeError as e:
314
+ logger.error(f"Failed to parse JSON response: {e}")
315
+ logger.error(f"Response text: {response_text}")
316
+ logger.warning("Returning fallback response due to JSON parse error")
317
+ return self._get_fallback_response("The spirits speak in riddles... try again.")
318
+
319
+ except GeminiTimeoutError as e:
320
+ # Timeout error - return fallback
321
+ logger.error(f"Timeout error: {e}")
322
+ return self._get_fallback_response("The spirits are slow to respond... try again.")
323
+
324
+ except Exception as e:
325
+ # Check for specific error types
326
+ error_type = type(e).__name__
327
+ error_message = str(e)
328
+
329
+ # Handle blocked prompt (safety filter)
330
+ if "BlockedPrompt" in error_type or "blocked" in error_message.lower() or "safety" in error_message.lower():
331
+ logger.warning(f"Prompt blocked by safety filter: {error_type}: {error_message}")
332
+ raise GeminiBlockedPromptError(f"Prompt blocked by safety filter: {error_message}") from e
333
+
334
+ # Handle rate limiting with exponential backoff retry
335
+ if self._is_rate_limit_error(e):
336
+ if retry_count < self.max_retries:
337
+ wait_time = (2 ** retry_count)  # Exponential backoff: 1s, 2s, 4s for max_retries=3
338
+ logger.warning(
339
+ f"Rate limit exceeded. Retrying in {wait_time}s... "
340
+ f"(attempt {retry_count + 1}/{self.max_retries})"
341
+ )
342
+ time.sleep(wait_time)
343
+ return self._send_message_with_retry(user_input, system_prompt, retry_count + 1)
344
+ else:
345
+ logger.error(f"Max retries ({self.max_retries}) exceeded for rate limit")
346
+ raise GeminiRateLimitError(
347
+ f"Rate limit exceeded after {self.max_retries} retries. Please try again later."
348
+ ) from e
349
+
350
+ # Handle timeout errors from Google API
351
+ if self._is_timeout_error(e):
352
+ logger.error(f"API timeout error: {error_type}: {error_message}")
353
+ return self._get_fallback_response("The spirits are slow to respond... try again.")
354
+
355
+ # Handle other API errors
356
+ logger.error(f"Gemini API error: {error_type}: {error_message}", exc_info=True)
357
+ return self._get_fallback_response("The spirits are restless... try again.")
358
+
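For reference, the wait schedule the exponential backoff above produces, as a one-line sketch:

# With max_retries=3, retry_count takes the values 0, 1, 2 before giving up:
waits = [2 ** n for n in range(3)]  # -> [1, 2, 4] seconds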
359
+ def _is_rate_limit_error(self, exception: Exception) -> bool:
360
+ """
361
+ Check if exception is a rate limit error.
362
+
363
+ Args:
364
+ exception: Exception to check
365
+
366
+ Returns:
367
+ True if rate limit error, False otherwise
368
+ """
369
+ error_type = type(exception).__name__
370
+ error_message = str(exception).lower()
371
+
372
+ # Check for common rate limit indicators
373
+ rate_limit_indicators = [
374
+ "ratelimit",
375
+ "rate_limit",
376
+ "resourceexhausted",
377
+ "resource_exhausted",
378
+ "429",
379
+ "quota",
380
+ "too many requests"
381
+ ]
382
+
383
+ return any(indicator in error_type.lower() or indicator in error_message
384
+ for indicator in rate_limit_indicators)
385
+
386
+ def _is_timeout_error(self, exception: Exception) -> bool:
387
+ """
388
+ Check if exception is a timeout error.
389
+
390
+ Args:
391
+ exception: Exception to check
392
+
393
+ Returns:
394
+ True if timeout error, False otherwise
395
+ """
396
+ error_type = type(exception).__name__
397
+ error_message = str(exception).lower()
398
+
399
+ # Check for common timeout indicators
400
+ timeout_indicators = [
401
+ "timeout",
402
+ "deadline",
403
+ "deadlineexceeded",
404
+ "deadline_exceeded"
405
+ ]
406
+
407
+ return any(indicator in error_type.lower() or indicator in error_message
408
+ for indicator in timeout_indicators)
409
+
410
+ def _get_fallback_response(self, message: Optional[str] = None) -> dict:
411
+ """
412
+ Return a safe fallback response when API fails.
413
+
414
+ Args:
415
+ message: Optional custom message (defaults to generic error message)
416
+
417
+ Returns:
418
+ Dictionary with fallback conversation response
419
+ """
420
+ default_message = "The spirits are restless... try again."
421
+ fallback_message = message or default_message
422
+
423
+ logger.info(f"Returning fallback response: {fallback_message}")
424
+ return {
425
+ "type": "conversation",
426
+ "message": fallback_message,
427
+ "mood": "ominous",
428
+ "gesture": "idle"
429
+ }
430
+
431
+ def configure_model(self, model_name: Optional[str] = None, temperature: Optional[float] = None):
432
+ """
433
+ Reconfigure the Gemini model settings.
434
+
435
+ Args:
436
+ model_name: New model name to use
437
+ temperature: New temperature value
438
+ """
439
+ if model_name:
440
+ self.model_name = model_name
441
+
442
+ if temperature is not None:
443
+ self.temperature = temperature
444
+
445
+ # Update generation config
446
+ self.generation_config = types.GenerateContentConfig(
447
+ temperature=self.temperature,
448
+ response_mime_type="application/json"
449
+ )
450
+
451
+ logger.info(f"Model reconfigured: {self.model_name}, temperature: {self.temperature}")
452
+
453
+
454
+ # Example usage
455
+ if __name__ == "__main__":
456
+ # Configure logging for testing
457
+ logging.basicConfig(level=logging.INFO)
458
+
459
+ # Create client
460
+ try:
461
+ client = GeminiClient()
462
+
463
+ # Test conversational message
464
+ print("Testing conversational input...")
465
+ response = client.send_message("Hello Mortis, introduce yourself!")
466
+ print("Response:", json.dumps(response, indent=2))
467
+ print()
468
+
469
+ # Test manipulation command
470
+ print("Testing manipulation command...")
471
+ response = client.send_message("Can you move the skull to the green cup?")
472
+ print("Response:", json.dumps(response, indent=2))
473
+ print()
474
+
475
+ # Test another manipulation with different wording
476
+ print("Testing manipulation with different wording...")
477
+ response = client.send_message("Put the eyeball in the orange cup")
478
+ print("Response:", json.dumps(response, indent=2))
479
+
480
+ except ValueError as e:
481
+ print(f"Error: {e}")
482
+ print("Please set GEMINI_API_KEY in your .env file")
src/mortis/intent_router.py ADDED
@@ -0,0 +1,295 @@
1
+ """
2
+ Intent router for parsing Gemini responses and routing to appropriate execution paths.
3
+
4
+ This module handles the routing logic between conversational gestures and manipulation
5
+ tasks based on Gemini API responses. It validates commands against the trained task set
6
+ and provides structured intent representation.
7
+ """
8
+
9
+ import json
10
+ import logging
11
+ from dataclasses import dataclass
12
+ from typing import Optional, List, Dict, Any
13
+
14
+ from .models import GeminiResponse, ResponseType, Gesture
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ @dataclass
20
+ class Intent:
21
+ """
22
+ Structured representation of user intent parsed from Gemini response.
23
+
24
+ Attributes:
25
+ type: The type of intent (conversation or manipulation)
26
+ message: The text message to display/speak to the user
27
+ mood: The emotional mood of the response
28
+ gesture: Optional gesture to execute (for conversation type)
29
+ command: Optional manipulation command (for manipulation type)
30
+ is_valid: Whether the intent is valid and can be executed
31
+ validation_error: Optional error message if validation failed
32
+ """
33
+ type: ResponseType
34
+ message: str
35
+ mood: str
36
+ gesture: Optional[str] = None
37
+ command: Optional[str] = None
38
+ is_valid: bool = True
39
+ validation_error: Optional[str] = None
40
+
41
+ @classmethod
42
+ def from_gemini_response(cls, response: GeminiResponse, is_valid: bool = True,
43
+ validation_error: Optional[str] = None) -> "Intent":
44
+ """
45
+ Create an Intent from a GeminiResponse.
46
+
47
+ Args:
48
+ response: The parsed GeminiResponse object
49
+ is_valid: Whether the intent passed validation
50
+ validation_error: Optional error message if validation failed
51
+
52
+ Returns:
53
+ Intent object with all fields populated
54
+ """
55
+ return cls(
56
+ type=response.type,
57
+ message=response.message,
58
+ mood=response.mood.value,
59
+ gesture=response.gesture.value if response.gesture else None,
60
+ command=response.command,
61
+ is_valid=is_valid,
62
+ validation_error=validation_error
63
+ )
64
+
65
+ def to_dict(self) -> Dict[str, Any]:
66
+ """
67
+ Convert the intent to a dictionary.
68
+
69
+ Returns:
70
+ Dictionary representation of the intent
71
+ """
72
+ result = {
73
+ "type": self.type.value,
74
+ "message": self.message,
75
+ "mood": self.mood,
76
+ "is_valid": self.is_valid,
77
+ }
78
+
79
+ if self.gesture is not None:
80
+ result["gesture"] = self.gesture
81
+
82
+ if self.command is not None:
83
+ result["command"] = self.command
84
+
85
+ if self.validation_error is not None:
86
+ result["validation_error"] = self.validation_error
87
+
88
+ return result
89
+
90
+
91
+ class IntentRouter:
92
+ """
93
+ Routes user intents to appropriate execution paths based on Gemini responses.
94
+
95
+ The IntentRouter parses Gemini API responses, validates manipulation commands
96
+ against the trained task set, and creates structured Intent objects for execution.
97
+ """
98
+
99
+ # Valid manipulation task commands that SmolVLA is trained on
100
+ VALID_COMMANDS = [
101
+ "Pick up the skull and place it in the green cup",
102
+ "Pick up the skull and place it in the orange cup",
103
+ "Pick up the skull and place it in the purple cup",
104
+ "Pick up the eyeball and place it in the green cup",
105
+ "Pick up the eyeball and place it in the orange cup",
106
+ "Pick up the eyeball and place it in the purple cup",
107
+ ]
108
+
109
+ def __init__(self, valid_commands: Optional[List[str]] = None):
110
+ """
111
+ Initialize the IntentRouter.
112
+
113
+ Args:
114
+ valid_commands: Optional list of valid manipulation commands.
115
+ If not provided, uses the default VALID_COMMANDS.
116
+ """
117
+ # Copy so add/remove_valid_command cannot mutate the shared class-level list
+ self.valid_commands = list(valid_commands) if valid_commands is not None else list(self.VALID_COMMANDS)
118
+ logger.info(f"IntentRouter initialized with {len(self.valid_commands)} valid commands")
119
+
120
+ def parse_gemini_response(self, response_data: Dict[str, Any]) -> Intent:
121
+ """
122
+ Parse a Gemini API response and create an Intent.
123
+
124
+ This method:
125
+ 1. Parses the JSON response into a GeminiResponse object
126
+ 2. Validates manipulation commands against the trained task set
127
+ 3. Creates an Intent object with validation results
128
+
129
+ Args:
130
+ response_data: Dictionary containing the JSON response from Gemini
131
+
132
+ Returns:
133
+ Intent object with parsed data and validation status
134
+
135
+ Raises:
136
+ ValueError: If the response structure is invalid
137
+ json.JSONDecodeError: If response_data is a string and not valid JSON
138
+ """
139
+ try:
140
+ # Parse the Gemini response
141
+ gemini_response = GeminiResponse.from_json(response_data)
142
+
143
+ # Validate the response structure
144
+ try:
145
+ gemini_response.validate()
146
+ except ValueError as e:
147
+ logger.warning(f"Response validation warning: {e}")
148
+ # Continue anyway - validation warnings are not fatal
149
+
150
+ # For manipulation intents, validate the command
151
+ if gemini_response.type == ResponseType.MANIPULATION:
152
+ is_valid = self.validate_command(gemini_response.command)
153
+
154
+ if not is_valid:
155
+ logger.warning(
156
+ f"Invalid manipulation command: '{gemini_response.command}'. "
157
+ f"Not in trained task set."
158
+ )
159
+ validation_error = (
160
+ f"Command '{gemini_response.command}' is not in the trained task set. "
161
+ f"Valid commands are: {', '.join(self.valid_commands)}"
162
+ )
163
+ return Intent.from_gemini_response(
164
+ gemini_response,
165
+ is_valid=False,
166
+ validation_error=validation_error
167
+ )
168
+ else:
169
+ logger.info(f"Valid manipulation command: '{gemini_response.command}'")
170
+
171
+ # For conversation intents, always valid (gestures are predefined)
172
+ else:
173
+ logger.info(f"Conversation intent with gesture: {gemini_response.gesture.value}")
174
+
175
+ # Create and return valid intent
176
+ return Intent.from_gemini_response(gemini_response, is_valid=True)
177
+
178
+ except (ValueError, KeyError) as e:
179
+ logger.error(f"Failed to parse Gemini response: {e}")
180
+ raise ValueError(f"Invalid Gemini response structure: {e}")
181
+
182
+ def parse_gemini_response_string(self, response_string: str) -> Intent:
183
+ """
184
+ Parse a Gemini API response from a JSON string.
185
+
186
+ Args:
187
+ response_string: JSON string containing the Gemini response
188
+
189
+ Returns:
190
+ Intent object with parsed data and validation status
191
+
192
+ Raises:
193
+ json.JSONDecodeError: If the string is not valid JSON
194
+ ValueError: If the response structure is invalid
195
+ """
196
+ try:
197
+ response_data = json.loads(response_string)
198
+ except json.JSONDecodeError as e:
199
+ logger.error(f"Failed to parse JSON string: {e}")
200
+ raise
201
+
202
+ return self.parse_gemini_response(response_data)
203
+
204
+ def validate_command(self, command: str) -> bool:
205
+ """
206
+ Validate that a manipulation command is in the trained task set.
207
+
208
+ This performs exact string matching against the list of valid commands.
209
+ Commands must match exactly (case-sensitive) to be considered valid.
210
+
211
+ Args:
212
+ command: The manipulation command string to validate
213
+
214
+ Returns:
215
+ True if the command is valid, False otherwise
216
+ """
217
+ if not command or not isinstance(command, str):
218
+ logger.warning(f"Invalid command type: {type(command)}")
219
+ return False
220
+
221
+ # Exact match required
222
+ is_valid = command in self.valid_commands
223
+
224
+ if not is_valid:
225
+ # Log for debugging - maybe it's close to a valid command
226
+ logger.debug(f"Command '{command}' not found in valid commands")
227
+ logger.debug(f"Valid commands: {self.valid_commands}")
228
+
229
+ return is_valid
230
+
231
+ def get_valid_commands(self) -> List[str]:
232
+ """
233
+ Get the list of valid manipulation commands.
234
+
235
+ Returns:
236
+ List of valid command strings
237
+ """
238
+ return self.valid_commands.copy()
239
+
240
+ def add_valid_command(self, command: str) -> None:
241
+ """
242
+ Add a new valid manipulation command to the router.
243
+
244
+ This is useful when training new tasks and expanding the command set.
245
+
246
+ Args:
247
+ command: The new command string to add
248
+ """
249
+ if command not in self.valid_commands:
250
+ self.valid_commands.append(command)
251
+ logger.info(f"Added new valid command: '{command}'")
252
+ else:
253
+ logger.warning(f"Command already exists: '{command}'")
254
+
255
+ def remove_valid_command(self, command: str) -> bool:
256
+ """
257
+ Remove a valid manipulation command from the router.
258
+
259
+ Args:
260
+ command: The command string to remove
261
+
262
+ Returns:
263
+ True if the command was removed, False if it wasn't found
264
+ """
265
+ if command in self.valid_commands:
266
+ self.valid_commands.remove(command)
267
+ logger.info(f"Removed valid command: '{command}'")
268
+ return True
269
+ else:
270
+ logger.warning(f"Command not found: '{command}'")
271
+ return False
272
+
273
+ def route_intent(self, intent: Intent) -> str:
274
+ """
275
+ Determine the execution path for an intent.
276
+
277
+ Args:
278
+ intent: The Intent object to route
279
+
280
+ Returns:
281
+ String indicating the execution path: "gesture", "manipulation", or "invalid"
282
+ """
283
+ if not intent.is_valid:
284
+ logger.warning(f"Invalid intent: {intent.validation_error}")
285
+ return "invalid"
286
+
287
+ if intent.type == ResponseType.CONVERSATION:
288
+ logger.info(f"Routing to gesture execution: {intent.gesture}")
289
+ return "gesture"
290
+ elif intent.type == ResponseType.MANIPULATION:
291
+ logger.info(f"Routing to manipulation execution: {intent.command}")
292
+ return "manipulation"
293
+ else:
294
+ logger.error(f"Unknown intent type: {intent.type}")
295
+ return "invalid"
src/mortis/lerobot_async_client.py ADDED
@@ -0,0 +1,668 @@
1
+ """
2
+ LeRobot async inference client wrapper for Mortis manipulation tasks.
3
+
4
+ This module provides a high-level interface to LeRobot's async inference system
5
+ (PolicyServer + RobotClient) for executing SmolVLA manipulation tasks while
6
+ keeping the Gradio UI responsive.
7
+
8
+ Architecture:
9
+ - PolicyServer: Runs in a separate thread, loads SmolVLA model, performs inference
10
+ - RobotClient: Controls the SO101 robot, captures observations, executes actions
11
+ - This wrapper: Manages lifecycle and provides simple API for Mortis
12
+ """
13
+
14
+ import logging
15
+ import threading
16
+ import time
17
+ from dataclasses import dataclass
18
+ from enum import Enum
19
+ from pathlib import Path
20
+ from typing import Optional, Dict, Any, Callable
21
+
22
+ from lerobot.robots.so101_follower import SO101FollowerConfig
23
+ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
24
+ from lerobot.cameras.realsense import RealSenseCameraConfig
25
+ from lerobot.async_inference.configs import PolicyServerConfig, RobotClientConfig
26
+ from lerobot.async_inference.policy_server import serve
27
+ from lerobot.async_inference.robot_client import RobotClient
28
+
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ class ManipulationStatus(Enum):
34
+ """Status of a manipulation task execution."""
35
+ IDLE = "idle"
36
+ STARTING = "starting"
37
+ RUNNING = "running"
38
+ COMPLETE = "complete"
39
+ FAILED = "failed"
40
+ STOPPED = "stopped"
41
+
42
+
43
+ @dataclass
44
+ class ManipulationTask:
45
+ """
46
+ Represents a manipulation task for LeRobot async execution.
47
+
48
+ Attributes:
49
+ task: Natural language task description
50
+ max_steps: Maximum number of action steps to execute
51
+ started_at: Timestamp when task started
52
+ completed_at: Timestamp when task completed
53
+ status: Current task status
54
+ error: Error message if task failed
55
+ """
56
+ task: str
57
+ max_steps: int = 1000 # At 30fps, ~33 seconds of execution
58
+ started_at: Optional[float] = None
59
+ completed_at: Optional[float] = None
60
+ status: ManipulationStatus = ManipulationStatus.IDLE
61
+ error: Optional[str] = None
62
+
63
+ @property
64
+ def duration(self) -> Optional[float]:
65
+ """Get task execution duration in seconds."""
66
+ if self.started_at and self.completed_at:
67
+ return self.completed_at - self.started_at
68
+ return None
69
+
70
+
71
+ class LeRobotAsyncClient:
72
+ """
73
+ High-level wrapper for LeRobot async inference system.
74
+
75
+ This class manages the PolicyServer and RobotClient lifecycle, providing
76
+ a simple interface for executing manipulation tasks asynchronously.
77
+
78
+ Usage:
79
+ # Create client
80
+ client = LeRobotAsyncClient(
81
+ robot_port="/dev/ttyACM1",
82
+ model_path="jlamperez/kiroween-potion-smolvla",
83
+ camera_configs={...}
84
+ )
85
+
86
+ # Start the system
87
+ client.start()
88
+
89
+ # Execute a task
90
+ client.execute_task("Pick up the skull and place it in the green cup")
91
+
92
+ # Check status
93
+ status = client.get_status()
94
+
95
+ # Stop when done
96
+ client.stop()
97
+ """
98
+
99
+ def __init__(
100
+ self,
101
+ robot_port: str = "/dev/ttyACM1",
102
+ robot_id: str = "my_follower_robot_arm", # Must match calibration file name
103
+ model_path: str = "jlamperez/kiroween-potion-smolvla",
104
+ policy_device: str = "cuda",
105
+ camera_configs: Optional[Dict[str, Any]] = None,
106
+ server_host: str = "127.0.0.1",
107
+ server_port: int = 8080,
108
+ actions_per_chunk: int = 50,
109
+ chunk_size_threshold: float = 0.5,
110
+ aggregate_fn_name: str = "weighted_average",
111
+ ):
112
+ """
113
+ Initialize the LeRobot async client.
114
+
115
+ Args:
116
+ robot_port: Serial port for SO101 robot (e.g., "/dev/ttyACM1")
117
+ robot_id: Identifier for the robot
118
+ model_path: HuggingFace model path or local checkpoint
119
+ policy_device: Device for model inference ("cuda" or "cpu")
120
+ camera_configs: Dictionary of camera configurations
121
+ server_host: PolicyServer host address
122
+ server_port: PolicyServer port
123
+ actions_per_chunk: Number of actions per inference chunk
124
+ chunk_size_threshold: Threshold for action chunk aggregation
125
+ aggregate_fn_name: Function name for aggregating action chunks
126
+ """
127
+ self.robot_port = robot_port
128
+ self.robot_id = robot_id
129
+ self.model_path = model_path
130
+ self.policy_device = policy_device
131
+ self.server_host = server_host
132
+ self.server_port = server_port
133
+ self.actions_per_chunk = actions_per_chunk
134
+ self.chunk_size_threshold = chunk_size_threshold
135
+ self.aggregate_fn_name = aggregate_fn_name
136
+
137
+ # Use default camera configs if not provided
138
+ self.camera_configs = camera_configs or self._get_default_camera_configs()
139
+
140
+ # Server and client instances
141
+ self.server_thread: Optional[threading.Thread] = None
142
+ self.robot_client: Optional[RobotClient] = None
143
+ self.action_receiver_thread: Optional[threading.Thread] = None
144
+ self.control_thread: Optional[threading.Thread] = None
145
+
146
+ # Current task tracking
147
+ self.current_task: Optional[ManipulationTask] = None
148
+ self._running = False
149
+ self._stop_event = threading.Event()
150
+ self._task_stop_event = threading.Event() # Event to signal task cancellation
151
+ self._idle_callback: Optional[Callable] = None # Callback to move robot to idle
152
+
153
+ logger.info(f"LeRobotAsyncClient initialized with model: {model_path}")
154
+
155
+ def _get_default_camera_configs(self) -> Dict[str, Any]:
156
+ """
157
+ Get default camera configuration for Mortis setup.
158
+
159
+ IMPORTANT: This configuration MUST match the cameras used during training!
160
+ If you trained with IntelRealSense + OpenCV, use the same setup here.
161
+
162
+ Returns:
163
+ Dictionary of camera configurations
164
+ """
165
+ # Default camera configuration matching training setup
166
+ # This should match your training configuration exactly!
167
+
168
+ # Configuration with RealSense + OpenCV (matches training setup)
169
+ return {
170
+ "camera1": RealSenseCameraConfig(
171
+ serial_number_or_name="030522070314",
172
+ width=640,
173
+ height=480,
174
+ fps=30
175
+ ),
176
+ "camera2": OpenCVCameraConfig(
177
+ index_or_path=8,
178
+ width=640,
179
+ height=480,
180
+ fps=30
181
+ )
182
+ }
183
+
184
+ def start(self) -> bool:
185
+ """
186
+ Start the PolicyServer only.
187
+
188
+ The RobotClient will be created lazily when the first task is executed.
189
+ This avoids loading the model unnecessarily at startup.
190
+
191
+ Returns:
192
+ True if startup successful, False otherwise
193
+ """
194
+ if self._running:
195
+ logger.warning("LeRobotAsyncClient is already running")
196
+ return True
197
+
198
+ try:
199
+ logger.info("Starting PolicyServer...")
200
+
201
+ # Configure and start PolicyServer
202
+ server_config = PolicyServerConfig(
203
+ host=self.server_host,
204
+ port=self.server_port
205
+ )
206
+
207
+ self.server_thread = threading.Thread(
208
+ target=serve,
209
+ args=(server_config,),
210
+ daemon=True,
211
+ name="PolicyServer"
212
+ )
213
+ self.server_thread.start()
214
+
215
+ # Give server time to start
216
+ time.sleep(2.0)
217
+ logger.info(f"PolicyServer started on {self.server_host}:{self.server_port}")
218
+
219
+ self._running = True
220
+ self._stop_event.clear()
221
+
222
+ logger.info("LeRobotAsyncClient started (RobotClient will be created on first task)")
223
+ return True
224
+
225
+ except Exception as e:
226
+ logger.error(f"Failed to start LeRobotAsyncClient: {e}", exc_info=True)
227
+ self.stop()
228
+ return False
229
+
230
+ def stop(self) -> None:
231
+ """
232
+ Stop the PolicyServer and RobotClient.
233
+
234
+ This method gracefully shuts down all components.
235
+ """
236
+ if not self._running:
237
+ logger.warning("LeRobotAsyncClient is not running")
238
+ return
239
+
240
+ logger.info("Stopping LeRobotAsyncClient...")
241
+
242
+ self._running = False
243
+ self._stop_event.set()
244
+
245
+ # Stop control thread if running
246
+ if self.control_thread and self.control_thread.is_alive():
247
+ logger.info("Waiting for control thread to finish...")
248
+ self.control_thread.join(timeout=5.0)
249
+
250
+ # Stop robot client
251
+ if self.robot_client:
252
+ try:
253
+ self.robot_client.stop()
254
+ logger.info("RobotClient stopped")
255
+ except Exception as e:
256
+ logger.error(f"Error stopping RobotClient: {e}")
257
+
258
+ # Action receiver thread should stop automatically (daemon)
259
+ # Server thread should stop automatically (daemon)
260
+
261
+ self.robot_client = None
262
+ self.server_thread = None
263
+ self.action_receiver_thread = None
264
+ self.control_thread = None
265
+
266
+ logger.info("LeRobotAsyncClient stopped")
267
+
268
+ def execute_task(
269
+ self,
270
+ task: str,
271
+ max_steps: int = 1000,
272
+ blocking: bool = False,
273
+ timeout: float = 60.0
274
+ ) -> bool:
275
+ """
276
+ Execute a manipulation task asynchronously.
277
+
278
+ This method stops any running task and creates a fresh RobotClient
279
+ for the new task, ensuring clean state.
280
+
281
+ Args:
282
+ task: Natural language task description
283
+ max_steps: Maximum number of action steps
284
+ blocking: If True, wait for task to complete before returning
285
+ timeout: Maximum execution time in seconds (default: 60.0)
286
+
287
+ Returns:
288
+ True if task started successfully, False otherwise
289
+ """
290
+ if not self._running:
291
+ logger.error("Cannot execute task: client not running")
292
+ return False
293
+
294
+ # Always need a fresh client for each task because control_loop can only run once
295
+ # But we keep the PolicyServer alive so the model stays loaded
296
+ need_new_client = True
297
+
298
+ if self.robot_client is None:
299
+ # First task - need to create client
300
+ logger.info("First task - creating RobotClient...")
301
+ elif self.current_task and self.current_task.status == ManipulationStatus.RUNNING:
302
+ # Task is running - stop it first
303
+ logger.info(f"Stopping previous task: {self.current_task.task}")
304
+ self._stop_robot_client()
305
+ else:
306
+ # Previous task finished - recreate client for new task
307
+ logger.info("Recreating RobotClient for new task (PolicyServer keeps model loaded)")
308
+
309
+ # Wait for previous control thread to finish
310
+ if self.control_thread and self.control_thread.is_alive():
311
+ logger.info("Waiting for previous control thread to finish...")
312
+ self.control_thread.join(timeout=3.0)
313
+ if self.control_thread.is_alive():
314
+ logger.warning("Previous control thread still running, proceeding anyway")
315
+
316
+ # Create new task
317
+ self.current_task = ManipulationTask(
318
+ task=task,
319
+ max_steps=max_steps,
320
+ status=ManipulationStatus.STARTING
321
+ )
322
+
323
+ # Clear any previous stop signal
324
+ self._task_stop_event.clear()
325
+
326
+ logger.info(f"Executing task: {task}")
327
+ logger.info(f"Limits: max_steps={max_steps}, timeout={timeout}s")
328
+
329
+ # Create/recreate robot client only if needed
330
+ if need_new_client:
331
+ if not self._recreate_robot_client(task):
332
+ logger.error("Failed to create robot client")
333
+ self.current_task.status = ManipulationStatus.FAILED
334
+ self.current_task.error = "Failed to initialize robot client"
335
+ return False
336
+
337
+ # Start control loop in separate thread
338
+ self.control_thread = threading.Thread(
339
+ target=self._run_control_loop,
340
+ args=(task, max_steps, timeout),
341
+ daemon=True,
342
+ name="ControlLoop"
343
+ )
344
+ self.control_thread.start()
345
+
346
+ if blocking:
347
+ self.control_thread.join()
348
+
349
+ return True
350
+
351
+ def _stop_robot_client(self) -> None:
352
+ """
353
+ Stop the robot client cleanly.
354
+
355
+ This stops the robot client and waits for threads to finish.
356
+ """
357
+ if self.robot_client:
358
+ try:
359
+ logger.info("Stopping robot client...")
360
+ self.robot_client.stop()
361
+
362
+ # Wait for action receiver thread
363
+ if self.action_receiver_thread and self.action_receiver_thread.is_alive():
364
+ self.action_receiver_thread.join(timeout=2.0)
365
+
366
+ logger.info("Robot client stopped")
367
+ except Exception as e:
368
+ logger.error(f"Error stopping robot client: {e}")
369
+
370
+ def _recreate_robot_client(self, task: str) -> bool:
371
+ """
372
+ Recreate the robot client with a new task.
373
+
374
+ This creates a fresh RobotClient instance for the new task,
375
+ ensuring clean state.
376
+
377
+ Args:
378
+ task: Task description for the new client
379
+
380
+ Returns:
381
+ True if successful, False otherwise
382
+ """
383
+ try:
384
+ # Stop existing client if any
385
+ self._stop_robot_client()
386
+
387
+ # Small delay to ensure port is released
388
+ time.sleep(0.5)
389
+
390
+ # Reconfigure robot (Path, SO101FollowerConfig, RobotClientConfig, and
+ # RobotClient are already imported at module level)
395
+
396
+ calibration_dir = Path(".cache/calibration/so101")
397
+ robot_config = SO101FollowerConfig(
398
+ port=self.robot_port,
399
+ id=self.robot_id,
400
+ cameras=self.camera_configs,
401
+ calibration_dir=calibration_dir
402
+ )
403
+
404
+ client_config = RobotClientConfig(
405
+ robot=robot_config,
406
+ server_address=f"{self.server_host}:{self.server_port}",
407
+ policy_device=self.policy_device,
408
+ policy_type="smolvla",
409
+ pretrained_name_or_path=self.model_path,
410
+ chunk_size_threshold=self.chunk_size_threshold,
411
+ actions_per_chunk=self.actions_per_chunk,
412
+ aggregate_fn_name=self.aggregate_fn_name,
413
+ task=task # Set the task in the config
414
+ )
415
+
416
+ # Create new robot client
417
+ self.robot_client = RobotClient(client_config)
418
+
419
+ if not self.robot_client.start():
420
+ raise RuntimeError("Failed to start RobotClient")
421
+
422
+ # Start action receiver thread
423
+ self.action_receiver_thread = threading.Thread(
424
+ target=self.robot_client.receive_actions,
425
+ daemon=True,
426
+ name="ActionReceiver"
427
+ )
428
+ self.action_receiver_thread.start()
429
+
430
+ logger.info("Robot client recreated successfully")
431
+ return True
432
+
433
+ except Exception as e:
434
+ logger.error(f"Failed to recreate robot client: {e}", exc_info=True)
435
+ return False
436
+
437
+ def stop_current_task(self) -> bool:
438
+ """
439
+ Stop the currently running task by stopping the robot client.
440
+
441
+ This cleanly stops the robot client, which will cause the control
442
+ loop to exit. The client will be recreated for the next task.
443
+
444
+ Returns:
445
+ True if task was stopped successfully
446
+ """
447
+ if not self.current_task or self.current_task.status != ManipulationStatus.RUNNING:
448
+ logger.warning("No task currently running to stop")
449
+ return False
450
+
451
+ logger.info("Stopping current task...")
452
+
453
+ try:
454
+ # Mark task as stopped
455
+ self.current_task.status = ManipulationStatus.STOPPED
456
+ self.current_task.completed_at = time.time()
457
+ self.current_task.error = "Task stopped by user"
458
+
459
+ # Signal task stop
460
+ self._task_stop_event.set()
461
+
462
+ # Stop the robot client (this will interrupt the control loop)
463
+ try:
464
+ self._stop_robot_client()
465
+ except Exception as e:
466
+ logger.warning(f"Error stopping client (expected): {e}")
467
+
468
+ # Move robot to idle position
469
+ if self._idle_callback:
470
+ logger.info("Moving robot to idle position...")
471
+ try:
472
+ self._idle_callback()
473
+ logger.info("Robot moved to idle position")
474
+ except Exception as e:
475
+ logger.error(f"Failed to move to idle: {e}")
476
+
477
+ logger.info("Task stopped successfully")
478
+
479
+ # Clear the task after a delay
480
+ def clear_task():
481
+ time.sleep(3.0)
482
+ if self.current_task and self.current_task.status == ManipulationStatus.STOPPED:
483
+ self.current_task = None
484
+ logger.info("Cleared stopped task from status")
485
+
486
+ clear_thread = threading.Thread(target=clear_task, daemon=True)
487
+ clear_thread.start()
488
+
489
+ return True
490
+
491
+ except Exception as e:
492
+ logger.error(f"Failed to stop task: {e}", exc_info=True)
493
+ return False
494
+
495
+ def _run_control_loop(self, task: str, max_steps: int, timeout: float) -> None:
496
+ """
497
+ Run the control loop for task execution with timeout.
498
+
499
+ This runs in a separate thread and executes the task using
500
+ the RobotClient's control_loop method. The timeout will stop
501
+ the task, and recreating the client for each task ensures clean state.
502
+
503
+ Note: max_steps is not directly enforced by LeRobot's control_loop,
504
+ but the timeout provides a time-based limit.
505
+
506
+ Args:
507
+ task: Task description
508
+ max_steps: Maximum steps (informational, not enforced)
509
+ timeout: Maximum execution time in seconds (default: 60.0)
510
+ """
511
+ if not self.current_task:
512
+ return
513
+
514
+ try:
515
+ self.current_task.status = ManipulationStatus.RUNNING
516
+ self.current_task.started_at = time.time()
517
+
518
+ logger.info(f"Starting control loop for: {task}")
519
+ logger.info(f"Timeout: {timeout}s (max_steps={max_steps} is informational)")
520
+
521
+ # Clear task stop event
522
+ self._task_stop_event.clear()
523
+
524
+ # Run control_loop in a separate thread so we can timeout
525
+ control_thread = threading.Thread(
526
+ target=lambda: self.robot_client.control_loop(task=task, verbose=False),
527
+ daemon=True,
528
+ name="ControlLoopInner"
529
+ )
530
+ control_thread.start()
531
+
532
+ # Wait for completion or timeout
533
+ control_thread.join(timeout=timeout)
534
+
535
+ # Check if thread is still alive (timeout occurred)
536
+ if control_thread.is_alive():
537
+ logger.warning(f"Task timed out after {timeout}s")
538
+
539
+ # Mark task as stopped first
540
+ self.current_task.status = ManipulationStatus.STOPPED
541
+ self.current_task.completed_at = time.time()
542
+ self.current_task.error = f"Task exceeded timeout of {timeout}s"
543
+
544
+ # Signal stop event
545
+ self._task_stop_event.set()
546
+
547
+ # Stop the robot client to interrupt the control loop
548
+ # This will cause the control thread to error out, but we catch it
549
+ logger.info("Stopping robot client to interrupt control loop...")
550
+ try:
551
+ self._stop_robot_client()
552
+ except Exception as e:
553
+ logger.warning(f"Error stopping client (expected): {e}")
554
+
555
+ # Wait a bit for thread to die
556
+ control_thread.join(timeout=2.0)
557
+
558
+ logger.info("Task stopped due to timeout")
559
+
560
+ # Move robot to idle position using callback if provided
561
+ if self._idle_callback:  # set in __init__, so no hasattr check is needed
562
+ logger.info("Moving robot to idle position...")
563
+ try:
564
+ self._idle_callback()
565
+ logger.info("Robot moved to idle position")
566
+ except Exception as e:
567
+ logger.error(f"Failed to move to idle: {e}")
568
+
569
+ # Clear the task after a delay so UI can show the stopped status
570
+ def clear_task():
571
+ time.sleep(3.0) # Show stopped status for 3 seconds
572
+ if self.current_task and self.current_task.status == ManipulationStatus.STOPPED:
573
+ self.current_task = None
574
+ logger.info("Cleared stopped task from status")
575
+
576
+ clear_thread = threading.Thread(target=clear_task, daemon=True)
577
+ clear_thread.start()
578
+
579
+ else:
580
+ # Task completed successfully
581
+ self.current_task.status = ManipulationStatus.COMPLETE
582
+ self.current_task.completed_at = time.time()
583
+ logger.info(f"Task completed in {self.current_task.duration:.2f}s")
584
+
585
+ # Clear completed task after showing status
586
+ def clear_task():
587
+ time.sleep(3.0) # Show completed status for 3 seconds
588
+ if self.current_task and self.current_task.status == ManipulationStatus.COMPLETE:
589
+ self.current_task = None
590
+ logger.info("Cleared completed task from status")
591
+
592
+ clear_thread = threading.Thread(target=clear_task, daemon=True)
593
+ clear_thread.start()
594
+
595
+ except KeyboardInterrupt:
596
+ logger.info("Task interrupted by user")
597
+ self.current_task.status = ManipulationStatus.STOPPED
598
+ self.current_task.completed_at = time.time()
599
+
600
+ except Exception as e:
601
+ logger.error(f"Task failed: {e}", exc_info=True)
602
+ self.current_task.status = ManipulationStatus.FAILED
603
+ self.current_task.error = str(e)
604
+ self.current_task.completed_at = time.time()
605
+
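The timeout mechanism above is plain threading.Thread.join with a timeout; here it is in isolation as a generic, runnable sketch:

import threading
import time

def work():
    time.sleep(5)  # stands in for robot_client.control_loop(...)

t = threading.Thread(target=work, daemon=True)
t.start()
t.join(timeout=1.0)  # give up waiting after 1 second
print("timed out" if t.is_alive() else "completed")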
606
+ def get_status(self) -> ManipulationStatus:
607
+ """
608
+ Get the current task status.
609
+
610
+ Returns:
611
+ Current ManipulationStatus
612
+ """
613
+ if self.current_task:
614
+ return self.current_task.status
615
+ return ManipulationStatus.IDLE
616
+
617
+ def get_current_task(self) -> Optional[ManipulationTask]:
618
+ """
619
+ Get the currently executing task.
620
+
621
+ Returns:
622
+ Current ManipulationTask or None if idle
623
+ """
624
+ return self.current_task
625
+
626
+ def is_busy(self) -> bool:
627
+ """
628
+ Check if a task is currently executing.
629
+
630
+ Returns:
631
+ True if a task is running
632
+ """
633
+ return (
634
+ self.current_task is not None and
635
+ self.current_task.status == ManipulationStatus.RUNNING
636
+ )
637
+
638
+ def is_running(self) -> bool:
639
+ """
640
+ Check if the client is running (server and robot connected).
641
+
642
+ Returns:
643
+ True if client is running
644
+ """
645
+ return self._running
646
+
647
+ def set_idle_callback(self, callback: Callable) -> None:
648
+ """
649
+ Set a callback function to move the robot to idle position.
650
+
651
+ This callback will be called when a task times out, to safely
652
+ return the robot to a neutral position.
653
+
654
+ Args:
655
+ callback: Function to call (e.g., lambda: mortis_arm.move_arm("idle"))
656
+ """
657
+ self._idle_callback = callback
658
+ logger.info("Idle callback configured")
659
+
660
+ def __enter__(self):
661
+ """Context manager entry: start the client."""
662
+ self.start()
663
+ return self
664
+
665
+ def __exit__(self, exc_type, exc_val, exc_tb):
666
+ """Context manager exit: stop the client."""
667
+ self.stop()
668
+ return False
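Since the class implements `__enter__`/`__exit__`, callers can scope the whole server/robot lifecycle in a `with` block. A minimal sketch; `make_client()` stands in for constructing the client class (its constructor is not shown here), and `mortis_arm` is assumed to be a connected MortisArm:

```python
# Sketch only: make_client() is a placeholder for however the client
# wrapper above is constructed; it is not a real function in this repo.
with make_client() as client:       # __enter__ calls start()
    client.set_idle_callback(lambda: mortis_arm.move_arm("idle"))
    if not client.is_busy():
        print(client.get_status())  # ManipulationStatus.IDLE when nothing runs
# __exit__ calls stop() even if an exception was raised
```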
src/mortis/models.py ADDED
@@ -0,0 +1,215 @@
1
+ """
2
+ Data models for Gemini API responses and intent routing.
3
+
4
+ This module defines the structured data types used throughout the Mortis system
5
+ for parsing Gemini responses, routing intents, and managing execution tasks.
6
+ """
7
+
8
+ import json
9
+ from dataclasses import dataclass
10
+ from enum import Enum
11
+ from typing import Optional, Dict, Any
12
+
13
+
14
+ class ResponseType(Enum):
15
+ """Type of response from Gemini API."""
16
+ CONVERSATION = "conversation"
17
+ MANIPULATION = "manipulation"
18
+
19
+
20
+ class Mood(Enum):
21
+ """Emotional mood for Mortis character responses."""
22
+ OMINOUS = "ominous"
23
+ PLAYFUL = "playful"
24
+ ANGRY = "angry"
25
+ NERVOUS = "nervous"
26
+ TRIUMPHANT = "triumphant"
27
+ MISCHIEVOUS = "mischievous"
28
+ SINISTER = "sinister"
29
+ CURIOUS = "curious"
30
+ NEUTRAL = "neutral"
31
+
32
+
33
+ class Gesture(Enum):
34
+ """Available gesture actions for the SO101 robotic arm."""
35
+ IDLE = "idle"
36
+ WAVE = "wave"
37
+ POINT_LEFT = "point_left"
38
+ POINT_RIGHT = "point_right"
39
+ GRAB = "grab"
40
+ DROP = "drop"
41
+
42
+
43
+ @dataclass
44
+ class GeminiResponse:
45
+ """
46
+ Structured response from Gemini API.
47
+
48
+ Attributes:
49
+ type: Whether this is a conversation or manipulation response
50
+ message: The text message to display/speak to the user
51
+ mood: The emotional mood of the response
52
+ gesture: Optional gesture to execute (for conversation type)
53
+ command: Optional manipulation command (for manipulation type)
54
+ """
55
+ type: ResponseType
56
+ message: str
57
+ mood: Mood
58
+ gesture: Optional[Gesture] = None
59
+ command: Optional[str] = None
60
+
61
+ @classmethod
62
+ def from_json(cls, json_data: Dict[str, Any]) -> "GeminiResponse":
63
+ """
64
+ Parse a GeminiResponse from JSON data returned by Gemini API.
65
+
66
+ Args:
67
+ json_data: Dictionary containing the JSON response from Gemini
68
+
69
+ Returns:
70
+ GeminiResponse object with validated fields
71
+
72
+ Raises:
73
+ ValueError: If required fields are missing or invalid
74
+ KeyError: If JSON structure is malformed
75
+ """
76
+ # Validate required fields
77
+ if "type" not in json_data:
78
+ raise ValueError("Missing required field: 'type'")
79
+ if "message" not in json_data:
80
+ raise ValueError("Missing required field: 'message'")
81
+ if "mood" not in json_data:
82
+ raise ValueError("Missing required field: 'mood'")
83
+
84
+ # Parse response type
85
+ try:
86
+ response_type = ResponseType(json_data["type"])
87
+ except ValueError:
88
+ raise ValueError(f"Invalid response type: {json_data['type']}. Must be 'conversation' or 'manipulation'")
89
+
90
+ # Parse mood
91
+ try:
92
+ mood = Mood(json_data["mood"])
93
+ except ValueError:
94
+ raise ValueError(f"Invalid mood: {json_data['mood']}. Must be one of: {[m.value for m in Mood]}")
95
+
96
+ # Parse optional fields based on response type
97
+ gesture = None
98
+ command = None
99
+
100
+ if response_type == ResponseType.CONVERSATION:
101
+ # Conversation responses should have a gesture
102
+ if "gesture" in json_data:
103
+ try:
104
+ gesture = Gesture(json_data["gesture"])
105
+ except ValueError:
106
+ raise ValueError(f"Invalid gesture: {json_data['gesture']}. Must be one of: {[g.value for g in Gesture]}")
107
+ else:
108
+ # Default to idle if no gesture specified
109
+ gesture = Gesture.IDLE
110
+
111
+ elif response_type == ResponseType.MANIPULATION:
112
+ # Manipulation responses must have a command
113
+ if "command" not in json_data:
114
+ raise ValueError("Manipulation responses must include 'command' field")
115
+ command = json_data["command"]
116
+ if not isinstance(command, str) or not command.strip():
117
+ raise ValueError("Command must be a non-empty string")
118
+
119
+ # Validate message
120
+ message = json_data["message"]
121
+ if not isinstance(message, str) or not message.strip():
122
+ raise ValueError("Message must be a non-empty string")
123
+
124
+ return cls(
125
+ type=response_type,
126
+ message=message,
127
+ mood=mood,
128
+ gesture=gesture,
129
+ command=command
130
+ )
131
+
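One subtlety worth calling out: conversation responses without an explicit gesture silently default to `Gesture.IDLE`, while manipulation responses hard-fail without a `command`. A quick demonstration:

```python
from mortis.models import GeminiResponse, Gesture

resp = GeminiResponse.from_json({
    "type": "conversation",
    "message": "I see you, mortal.",
    "mood": "sinister",
})
assert resp.gesture is Gesture.IDLE  # defaulted to idle, not left as None
# A "manipulation" payload without a "command" field raises ValueError instead.
```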
132
+ @classmethod
133
+ def from_json_string(cls, json_string: str) -> "GeminiResponse":
134
+ """
135
+ Parse a GeminiResponse from a JSON string.
136
+
137
+ Args:
138
+ json_string: JSON string containing the Gemini response
139
+
140
+ Returns:
141
+ GeminiResponse object with validated fields
142
+
143
+ Raises:
144
+ json.JSONDecodeError: If the string is not valid JSON
145
+ ValueError: If required fields are missing or invalid
146
+ """
147
+ try:
148
+ json_data = json.loads(json_string)
149
+ except json.JSONDecodeError as e:
150
+ raise json.JSONDecodeError(f"Invalid JSON string: {e.msg}", e.doc, e.pos)
151
+
152
+ return cls.from_json(json_data)
153
+
154
+ def validate(self) -> bool:
155
+ """
156
+ Validate the response structure and content.
157
+
158
+ Returns:
159
+ True if the response is valid
160
+
161
+ Raises:
162
+ ValueError: If validation fails
163
+ """
164
+ # Check message length constraints (per product requirements)
165
+ if len(self.message) > 120:
166
+ raise ValueError(f"Message exceeds 120 characters: {len(self.message)} chars")
167
+
168
+ word_count = len(self.message.split())
169
+ if word_count > 30:
170
+ raise ValueError(f"Message exceeds 30 words: {word_count} words")
171
+
172
+ # Validate type-specific requirements
173
+ if self.type == ResponseType.CONVERSATION:
174
+ if self.gesture is None:
175
+ raise ValueError("Conversation responses must have a gesture")
176
+ if self.command is not None:
177
+ raise ValueError("Conversation responses should not have a command")
178
+
179
+ elif self.type == ResponseType.MANIPULATION:
180
+ if self.command is None or not self.command.strip():
181
+ raise ValueError("Manipulation responses must have a non-empty command")
182
+ if self.gesture is not None:
183
+ raise ValueError("Manipulation responses should not have a gesture")
184
+
185
+ return True
186
+
187
+ def to_dict(self) -> Dict[str, Any]:
188
+ """
189
+ Convert the response to a dictionary.
190
+
191
+ Returns:
192
+ Dictionary representation of the response
193
+ """
194
+ result = {
195
+ "type": self.type.value,
196
+ "message": self.message,
197
+ "mood": self.mood.value,
198
+ }
199
+
200
+ if self.gesture is not None:
201
+ result["gesture"] = self.gesture.value
202
+
203
+ if self.command is not None:
204
+ result["command"] = self.command
205
+
206
+ return result
207
+
208
+ def to_json(self) -> str:
209
+ """
210
+ Convert the response to a JSON string.
211
+
212
+ Returns:
213
+ JSON string representation of the response
214
+ """
215
+ return json.dumps(self.to_dict(), indent=2)
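A round-trip sketch tying the parsing, validation, and serialization helpers together (the JSON payload is illustrative):

```python
from mortis.models import GeminiResponse, ResponseType

raw = (
    '{"type": "manipulation", "mood": "triumphant", '
    '"message": "The skull shall move.", '
    '"command": "Pick up the skull and place it in the green cup"}'
)
resp = GeminiResponse.from_json_string(raw)
resp.validate()                       # raises ValueError on any constraint violation
assert resp.type is ResponseType.MANIPULATION
print(resp.to_json())                 # serializes back to the same fields
```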
src/mortis/robot.py ADDED
@@ -0,0 +1,180 @@
2
+ import logging
3
+ import os
4
+ import time
5
+ from pathlib import Path
6
+
7
+ from lerobot.robots.so101_follower import SO101Follower, SO101FollowerConfig
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ HOME_POSE = {
12
+ "shoulder_pan.pos": 0,
13
+ "shoulder_lift.pos": -99,
14
+ "elbow_flex.pos": 97,
15
+ "wrist_flex.pos": 55,
16
+ "wrist_roll.pos": 0,
17
+ "gripper.pos": 0,
18
+ }
19
+
20
+
21
+ GESTURES = {
22
+ "idle": [
23
+ (HOME_POSE, 1.0),
24
+ ],
25
+ "wave": [
26
+ ({"wrist_flex.pos": -40}, 0.5),
27
+ ({"shoulder_pan.pos": -5, "shoulder_lift.pos": 65, "elbow_flex.pos": -70}, 1),
28
+ ({"shoulder_lift.pos": 0, "elbow_flex.pos": 0}, 0.5),
29
+ ({"wrist_flex.pos": 0}, 0.5),
30
+ (HOME_POSE, 1.0),
31
+ ],
32
+ "point_left": [
33
+ ({"shoulder_pan.pos": -60, "shoulder_lift.pos": -30, "elbow_flex.pos": -15, "wrist_flex.pos": 42, "wrist_roll.pos": 0, "gripper.pos": 0}, 1),
34
+ ({"wrist_flex.pos": 80}, 0.5),
35
+ ({"wrist_flex.pos": 42}, 0.5),
36
+ ({"wrist_flex.pos": 80}, 0.5),
37
+ (HOME_POSE, 1.0),
38
+ ],
39
+ "point_right": [
40
+ ({"shoulder_pan.pos": 65, "shoulder_lift.pos": -50, "elbow_flex.pos": -5, "wrist_flex.pos": 55, "wrist_roll.pos": 0, "gripper.pos": 0}, 1),
41
+ ({"wrist_flex.pos": 90}, 0.5),
42
+ ({"wrist_flex.pos": 42}, 0.5),
43
+ ({"wrist_flex.pos": 90}, 0.5),
44
+ (HOME_POSE, 1.0),
45
+ ],
46
+ "grab": [
47
+ ({"shoulder_pan.pos": 0, "shoulder_lift.pos": -2, "elbow_flex.pos": -8.0, "wrist_flex.pos": 55, "wrist_roll.pos": 0, "gripper.pos": 0}, 0.8),
48
+ ({"wrist_flex.pos": 80}, 0.5),
49
+ ({"wrist_roll.pos": -45, "gripper.pos": 40}, 1),
50
+ ({"elbow_flex.pos": 30}, 1),
51
+ ({"wrist_roll.pos": 45, "gripper.pos": 10}, 1),
52
+ ({"elbow_flex.pos": -20}, 1),
53
+ (HOME_POSE, 1.0),
54
+ ],
55
+ "drop": [
56
+ ({"shoulder_pan.pos": 0, "shoulder_lift.pos": 5, "elbow_flex.pos": 20.0, "wrist_flex.pos": 55, "wrist_roll.pos": 0, "gripper.pos": 0}, 0.8),
57
+ ({"gripper.pos": 80}, 1),
58
+ ({"gripper.pos": 00}, 1),
59
+ (HOME_POSE, 1.0),
60
+ ],
61
+ }
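Each gesture is a list of `(partial_pose, delay_seconds)` steps, and a step's dict only needs the joints that should move. Extending the table is therefore just another entry; a hypothetical example:

```python
from mortis.robot import GESTURES, HOME_POSE

# Hypothetical extra gesture: each step commands only the joints that move.
GESTURES["nod"] = [
    ({"wrist_flex.pos": 80}, 0.4),
    ({"wrist_flex.pos": 30}, 0.4),
    ({"wrist_flex.pos": 80}, 0.4),
    (HOME_POSE, 1.0),  # settle back to the neutral pose
]
```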
62
+
63
+
64
+ class MortisArm:
65
+ """
66
+ Class to control the Mortis SO101 robotic arm.
67
+ Manages connection, disconnection, and gesture execution.
68
+
69
+ Supports two modes:
70
+ - physical: Connects to real robot hardware
71
+ - simulation: Simulates robot behavior without hardware
72
+ """
73
+
74
+ def __init__(self, port="/dev/ttyACM1", mode=None):
75
+ port = os.getenv("ROBOT_PORT", port)
76
+
77
+ # Determine mode: check env var or use provided mode
78
+ if mode is None:
79
+ mode = os.getenv("ROBOT_MODE", "physical").lower()
80
+
81
+ self.mode = mode
82
+ self.connected = False
83
+
84
+ if self.mode == "simulation":
85
+ logger.info("🎭 MortisArm initialized in SIMULATION mode (no physical robot)")
86
+ self.robot = None
87
+ self.connected = True # Always "connected" in simulation
88
+ else:
89
+ config = SO101FollowerConfig(
90
+ port=port,
91
+ id="my_follower_robot_arm",
92
+ calibration_dir=Path(".cache/calibration/so101/"),
93
+ )
94
+ self.robot = SO101Follower(config)
95
+ logger.info(f"🤖 MortisArm initialized in PHYSICAL mode on port {port}")
96
+
97
+ def connect(self):
98
+ """Connects to the robotic arm."""
99
+ if self.mode == "simulation":
100
+ logger.info("🎭 Simulation mode: skipping physical connection")
101
+ self.connected = True
102
+ return
103
+
104
+ if not self.connected:
105
+ try:
106
+ logger.info("Attempting to connect to robot arm...")
107
+ self.robot.connect()
108
+ self.connected = self.robot.is_connected
109
+ if self.connected:
110
+ logger.info("✅ Robot arm connected successfully")
111
+ # Move to the initial position to indicate it's ready
112
+ self.move_arm("idle")
113
+ else:
114
+ logger.warning("⚠️ Failed to establish connection to robot arm")
115
+ except Exception as e:
116
+ logger.error(f"❌ Connection error: {e}", exc_info=True)
117
+ self.connected = False
118
+
119
+ def disconnect(self):
120
+ """Disconnects the robotic arm."""
121
+ if self.mode == "simulation":
122
+ logger.info("🎭 Simulation mode: skipping physical disconnection")
123
+ self.connected = False
124
+ return
125
+
126
+ if self.connected:
127
+ logger.info("Disconnecting robot arm...")
128
+ # Move to rest position before disconnecting
129
+ self.move_arm("idle")
130
+ time.sleep(1)
131
+ self.robot.disconnect()
132
+ self.connected = False
133
+ logger.info("✅ Robot arm disconnected")
134
+
135
+ def move_arm(self, gesture_name: str):
136
+ """
137
+ Executes a sequence of movements (a gesture) by its name.
138
+ If the gesture does not exist, it executes 'idle'.
139
+ """
140
+ if not self.connected:
141
+ logger.warning("⚠️ Cannot execute gesture: robot arm not connected")
142
+ return
143
+
144
+ # If the gesture is not defined, return to the neutral position.
145
+ if gesture_name not in GESTURES:
146
+ logger.warning(f"⚠️ Unknown gesture '{gesture_name}', falling back to 'idle'")
147
+ gesture_name = "idle"
148
+
149
+ sequence = GESTURES[gesture_name]
150
+
151
+ if self.mode == "simulation":
152
+ # Simulation mode: just log the gesture
153
+ logger.info(f"🎭 [SIMULATION] Executing gesture '{gesture_name}' ({len(sequence)} steps)")
154
+
155
+ # Simulate timing by sleeping for total duration
156
+ total_delay = sum(delay for _, delay in sequence)
157
+ time.sleep(total_delay)
158
+
159
+ logger.info(f"🎭 [SIMULATION] Gesture '{gesture_name}' completed")
160
+ else:
161
+ # Physical mode: execute on real robot
162
+ logger.info(f"🤖 Executing gesture '{gesture_name}' ({len(sequence)} steps)")
163
+
164
+ for i, (action, delay) in enumerate(sequence, 1):
165
+ logger.debug(f"Gesture '{gesture_name}' step {i}/{len(sequence)}: {action}")
166
+ self.robot.send_action(action)
167
+ time.sleep(delay)
168
+
169
+ logger.info(f"✅ Gesture '{gesture_name}' completed")
170
+
171
+
172
+ if __name__ == "__main__":
173
+
174
+ mortis_arm = MortisArm()
175
+ if not mortis_arm.connected:
176
+ mortis_arm.connect()
177
+
178
+ mortis_arm.move_arm("drop")
179
+
180
+ mortis_arm.disconnect()
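Because mode resolution checks the `ROBOT_MODE` environment variable first, the same demo can run as a smoke test without hardware; a sketch:

```python
import os

os.environ["ROBOT_MODE"] = "simulation"  # must be set before MortisArm() is created

from mortis.robot import MortisArm

arm = MortisArm()      # logs that it is in SIMULATION mode
arm.connect()          # no-op beyond flagging connected=True
arm.move_arm("wave")   # logs each step and sleeps for the summed delays
arm.disconnect()
```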
src/mortis/setup_dataset.py ADDED
@@ -0,0 +1,146 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ CLI tool for setting up Mortis dataset infrastructure.
4
+
5
+ This script initializes the dataset structure and generates
6
+ lerobot-record scripts for data collection.
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import subprocess
12
+ import argparse
13
+ from pathlib import Path
14
+ from dotenv import load_dotenv
15
+
16
+ from mortis.data_collector import DataCollector
17
+
18
+
19
+ def check_huggingface_auth():
20
+ """Check if user is authenticated with Hugging Face."""
21
+ try:
22
+ result = subprocess.run(
23
+ ["huggingface-cli", "whoami"],
24
+ capture_output=True,
25
+ text=True,
26
+ timeout=5
27
+ )
28
+ return result.returncode == 0
29
+ except (subprocess.TimeoutExpired, FileNotFoundError):
30
+ return False
31
+
32
+
33
+ def main():
34
+ """Main entry point for dataset setup."""
35
+ # Parse command line arguments
36
+ parser = argparse.ArgumentParser(
37
+ description="Setup Mortis dataset infrastructure and generate recording scripts"
38
+ )
39
+ parser.add_argument(
40
+ "--dataset-name",
41
+ type=str,
42
+ default=None,
43
+ help="Name for the dataset (default: mortis_manipulation)"
44
+ )
45
+ parser.add_argument(
46
+ "--hf-user",
47
+ type=str,
48
+ default=None,
49
+ help="Hugging Face username (default: from HF_USER env var)"
50
+ )
51
+ args = parser.parse_args()
52
+
53
+ # Load environment variables from .env file
54
+ REPO_ROOT = Path(__file__).resolve().parents[2]
55
+ load_dotenv(REPO_ROOT / ".env")
56
+
57
+
58
+ print("="*70)
59
+ print("Mortis Dataset Setup")
60
+ print("="*70)
61
+ print()
62
+
63
+ # Check Hugging Face authentication
64
+ print("Checking Hugging Face authentication...")
65
+ if not check_huggingface_auth():
66
+ print("⚠️ Not logged in to Hugging Face")
67
+ print("📝 You need to authenticate before recording datasets")
68
+ print()
69
+ print("Run this command to login:")
70
+ print(" huggingface-cli login")
71
+ print()
72
+ print("Get your token from: https://huggingface.co/settings/tokens")
73
+ print()
74
+ response = input("Continue anyway? (y/N): ").strip().lower()
75
+ if response != 'y':
76
+ print("Setup cancelled. Please login first with: huggingface-cli login")
77
+ sys.exit(0)
78
+ print()
79
+ else:
80
+ print("✅ Hugging Face authentication verified")
81
+ print()
82
+
83
+ # Get Hugging Face username
84
+ hf_user = args.hf_user or os.getenv("HF_USER")
85
+ if not hf_user:
86
+ print("⚠️ HF_USER not found in .env file or environment")
87
+ hf_user = input("Enter your Hugging Face username: ").strip()
88
+ if not hf_user:
89
+ print("❌ Hugging Face username is required")
90
+ sys.exit(1)
91
+ print(f"💡 Tip: Add HF_USER to your .env file to skip this prompt:")
92
+ print(f" echo 'HF_USER={hf_user}' >> .env")
93
+ print()
94
+
95
+ # Get dataset name
96
+ dataset_name = args.dataset_name
97
+ if not dataset_name:
98
+ print("Dataset name:")
99
+ print(" Press Enter for default: 'mortis_manipulation'")
100
+ print(" Or enter a custom name (e.g., 'mortis_v2', 'test_dataset')")
101
+ user_input = input("Dataset name: ").strip()
102
+ dataset_name = user_input if user_input else "mortis_manipulation"
103
+ print()
104
+
105
+ # Create repository ID
106
+ repo_id = f"{hf_user}/{dataset_name}"
107
+
108
+ print(f"Creating dataset: {dataset_name}")
109
+ print(f"Repository: {repo_id}")
110
+ print()
111
+
112
+ # Create collector with custom name
113
+ collector = DataCollector(dataset_name, repo_id)
114
+
115
+ # Generate scripts
116
+ print("\nGenerating recording scripts...")
117
+ collector.generate_all_record_scripts()
118
+
119
+ # Show summary
120
+ collector.print_summary()
121
+
122
+ # Show instructions
123
+ collector.print_recording_instructions()
124
+
125
+ # Final instructions
126
+ print("="*70)
127
+ print("Setup Complete! 🎉")
128
+ print("="*70)
129
+ print()
130
+ print("Next steps:")
131
+ print(" 1. Make sure you're logged in to Hugging Face:")
132
+ print(" huggingface-cli login")
133
+ print(" 2. Connect your leader and follower robot arms")
134
+ print(" 3. Navigate to the scripts directory:")
135
+ print(f" cd {collector.dataset_dir}/scripts")
136
+ print(" 4. Run a recording script:")
137
+ print(" ./record_task_0.sh")
138
+ print()
139
+ print("Or record all tasks:")
140
+ print(" ./record_all_tasks.sh")
141
+ print()
142
+ print("="*70)
143
+
144
+
145
+ if __name__ == "__main__":
146
+ main()
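The interactive prompts are conveniences; the same flow can be scripted with the calls `main()` makes (the repo id below is illustrative):

```python
from mortis.data_collector import DataCollector

repo_id = "your-hf-user/mortis_manipulation"  # illustrative
collector = DataCollector("mortis_manipulation", repo_id)
collector.generate_all_record_scripts()
collector.print_summary()
collector.print_recording_instructions()
```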
src/mortis/setup_train.py ADDED
@@ -0,0 +1,385 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ CLI tool for setting up Mortis training infrastructure.
4
+
5
+ This script generates lerobot-train scripts with appropriate
6
+ configurations for training SmolVLA models on Mortis datasets.
7
+ """
8
+
9
+ from typing import Optional
11
+ import argparse
12
+ from pathlib import Path
13
+ from dotenv import load_dotenv
14
+
15
+
16
+ class TrainingScriptGenerator:
17
+ """
18
+ Helper for generating lerobot-train scripts.
19
+
20
+ This class generates shell scripts that call lerobot-train with the
21
+ correct parameters for training SmolVLA models on Mortis datasets.
22
+
23
+ Attributes:
24
+ dataset_repo_id: Hugging Face dataset repository ID
25
+ output_dir: Directory for training outputs
26
+ job_name: Name for the training job
27
+ model_repo_id: Optional Hugging Face model repository ID for pushing
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ dataset_repo_id: str,
33
+ output_dir: str = "outputs/train",
34
+ job_name: str = "smolvla_mortis",
35
+ model_repo_id: Optional[str] = None,
36
+ scripts_dir: str = "train"
37
+ ):
38
+ """
39
+ Initialize the TrainingScriptGenerator.
40
+
41
+ Args:
42
+ dataset_repo_id: Hugging Face dataset repository ID
43
+ output_dir: Base directory for training outputs (checkpoints, logs)
44
+ job_name: Name for the training job
45
+ model_repo_id: Optional HF model repo ID for pushing trained model
46
+ scripts_dir: Directory to save training scripts
47
+ """
48
+ self.dataset_repo_id = dataset_repo_id
49
+ self.output_dir = Path(output_dir)
50
+ self.job_name = job_name
51
+ self.model_repo_id = model_repo_id
52
+ self.scripts_dir = Path(scripts_dir)
53
+
54
+ # Create scripts directory
55
+ self.scripts_dir.mkdir(parents=True, exist_ok=True)
56
+
57
+ print(f"TrainingScriptGenerator initialized:")
58
+ print(f" Dataset: {self.dataset_repo_id}")
59
+ print(f" Scripts directory: {self.scripts_dir}")
60
+ print(f" Training output directory: {self.output_dir}")
61
+ print(f" Job name: {self.job_name}")
62
+ if self.model_repo_id:
63
+ print(f" Model repository: {self.model_repo_id}")
64
+
65
+ def generate_train_command(
66
+ self,
67
+ policy_path: str = "lerobot/smolvla_base",
68
+ batch_size: int = 16,
69
+ steps: int = 20000,
70
+ save_freq: int = 5000,
71
+ eval_freq: int = 5000,
72
+ n_action_steps: int = 50,
73
+ chunk_size: int = 50,
74
+ use_amp: bool = True,
75
+ enable_wandb: bool = True,
76
+ device: str = "cuda",
77
+ image_transforms: bool = True,
78
+ rename_map: Optional[str] = None,
79
+ cuda_alloc_conf: str = "expandable_segments:True"
80
+ ) -> str:
81
+ """
82
+ Generate a lerobot-train command with specified parameters.
83
+
84
+ Args:
85
+ policy_path: Path to base policy (default: lerobot/smolvla_base)
86
+ batch_size: Training batch size
87
+ steps: Total training steps
88
+ save_freq: Checkpoint save frequency
89
+ eval_freq: Evaluation frequency
90
+ n_action_steps: Number of action steps to predict
91
+ chunk_size: Action chunk size
92
+ use_amp: Use automatic mixed precision
93
+ enable_wandb: Enable Weights & Biases logging
94
+ device: Device to use (cuda or cpu)
95
+ image_transforms: Enable image transformations
96
+ rename_map: Optional observation key rename mapping
97
+ cuda_alloc_conf: CUDA memory allocator configuration
98
+
99
+ Returns:
100
+ The complete lerobot-train command as a string
101
+ """
102
+ # Load environment variables
103
+ load_dotenv()
104
+
105
+ # Build output directory path
106
+ full_output_dir = self.output_dir / self.job_name
107
+
108
+ # Default rename map for SO101 with dual cameras
109
+ if rename_map is None:
110
+ rename_map = (
111
+ '{"observation.images.camera1": "observation.images.camera1", '
112
+ '"observation.images.camera2": "observation.images.camera2"}'
113
+ )
114
+
115
+ # Build the command
116
+ cmd_parts = [
117
+ f"PYTORCH_CUDA_ALLOC_CONF={cuda_alloc_conf} \\",
118
+ "lerobot-train \\",
119
+ f" --policy.path={policy_path} \\",
120
+ f" --dataset.repo_id={self.dataset_repo_id} \\",
121
+ f" --dataset.image_transforms.enable={str(image_transforms).lower()} \\",
122
+ f" --policy.device={device} \\",
123
+ f" --policy.use_amp={str(use_amp).lower()} \\",
124
+ f" --policy.n_action_steps={n_action_steps} \\",
125
+ f" --policy.chunk_size={chunk_size} \\",
126
+ f" --batch_size={batch_size} \\",
127
+ f" --steps={steps} \\",
128
+ f" --save_checkpoint=true \\",
129
+ f" --save_freq={save_freq} \\",
130
+ f" --eval_freq={eval_freq} \\",
131
+ f" --wandb.enable={str(enable_wandb).lower()} \\",
132
+ f" --output_dir={full_output_dir} \\",
133
+ f" --job_name={self.job_name} \\",
134
+ ]
135
+
136
+ # Add model repo ID if specified
137
+ if self.model_repo_id:
138
+ cmd_parts.append(f" --policy.repo_id={self.model_repo_id} \\")
139
+
140
+ # Add rename map
141
+ cmd_parts.append(f" --rename_map='{rename_map}'")
142
+
143
+ return "\n".join(cmd_parts)
144
+
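Since the method only assembles a newline-joined string, the resulting command can be inspected without writing any script; a sketch (repo id illustrative):

```python
gen = TrainingScriptGenerator("your-hf-user/mortis_manipulation")  # illustrative repo id
print(gen.generate_train_command(steps=1000, batch_size=8, enable_wandb=False))
```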
145
+ def generate_training_script(
146
+ self,
147
+ script_name: str = "train.sh",
148
+ **kwargs
149
+ ) -> Path:
150
+ """
151
+ Generate a shell script for training.
152
+
153
+ Args:
154
+ script_name: Name for the training script
155
+ **kwargs: Additional arguments passed to generate_train_command
156
+
157
+ Returns:
158
+ Path to the generated script
159
+ """
160
+ script_path = self.scripts_dir / script_name
161
+
162
+ with open(script_path, 'w') as f:
163
+ f.write("#!/bin/bash\n")
164
+ f.write(f"# Training script for {self.job_name}\n")
165
+ f.write(f"# Dataset: {self.dataset_repo_id}\n")
166
+ f.write(f"# Generated by setup_train.py\n\n")
167
+
168
+ f.write("# Check if CUDA is available\n")
169
+ f.write("if ! command -v nvidia-smi &> /dev/null; then\n")
170
+ f.write(' echo "⚠️ Warning: nvidia-smi not found. CUDA may not be available."\n')
171
+ f.write(' read -p "Continue anyway? (y/N): " -n 1 -r\n')
172
+ f.write(' echo\n')
173
+ f.write(' if [[ ! $REPLY =~ ^[Yy]$ ]]; then\n')
174
+ f.write(' exit 1\n')
175
+ f.write(' fi\n')
176
+ f.write("fi\n\n")
177
+
178
+ f.write("# Start training\n")
179
+ f.write(f'echo "Starting training: {self.job_name}"\n')
180
+ f.write(f'echo "Dataset: {self.dataset_repo_id}"\n')
181
+ f.write(f'echo "Output: {self.output_dir / self.job_name}"\n')
182
+ f.write('echo ""\n\n')
183
+
184
+ f.write(self.generate_train_command(**kwargs))
185
+ f.write("\n")
186
+
187
+ # Make script executable
188
+ script_path.chmod(0o755)
189
+ print(f"Created: {script_path}")
190
+
191
+ return script_path
192
+
193
+ def generate_training_configs(self):
194
+ """
195
+ Generate multiple training scripts with different configurations.
196
+
197
+ Creates:
198
+ - train_quick.sh: Quick test training (1000 steps)
199
+ - train_standard.sh: Standard training (20k steps)
200
+ - train_full.sh: Full training (100k steps)
201
+ """
202
+ configs = [
203
+ {
204
+ "script_name": "train_quick.sh",
205
+ "steps": 1000,
206
+ "save_freq": 500,
207
+ "eval_freq": 500,
208
+ "batch_size": 8,
209
+ },
210
+ {
211
+ "script_name": "train_standard.sh",
212
+ "steps": 20000,
213
+ "save_freq": 5000,
214
+ "eval_freq": 5000,
215
+ "batch_size": 16,
216
+ },
217
+ {
218
+ "script_name": "train_full.sh",
219
+ "steps": 100000,
220
+ "save_freq": 10000,
221
+ "eval_freq": 10000,
222
+ "batch_size": 16,
223
+ },
224
+ ]
225
+
226
+ for config in configs:
227
+ self.generate_training_script(**config)
228
+
229
+ print(f"\n✅ Generated {len(configs)} training scripts in {self.scripts_dir}")
230
+
231
+ def print_usage_instructions(self):
232
+ """Print instructions for using the generated training scripts."""
233
+ print("\n" + "="*70)
234
+ print("Training Scripts Generated")
235
+ print("="*70)
236
+ print()
237
+ print("Available training scripts:")
238
+ print(f" {self.scripts_dir}/train_quick.sh - Quick test (1k steps)")
239
+ print(f" {self.scripts_dir}/train_standard.sh - Standard training (20k steps)")
240
+ print(f" {self.scripts_dir}/train_full.sh - Full training (100k steps)")
241
+ print()
242
+ print("To start training:")
243
+ print(f" cd {self.scripts_dir}")
244
+ print(" ./train_standard.sh")
245
+ print()
246
+ print("Training outputs will be saved to:")
247
+ print(f" {self.output_dir}/{self.job_name}/")
248
+ print()
249
+ print("Monitor training:")
250
+ print(" - Console: Watch the terminal output")
251
+ print(" - W&B: https://wandb.ai (if enabled)")
252
+ print(f" - Checkpoints: {self.output_dir}/{self.job_name}/checkpoints/")
253
+ print()
254
+ print("Resume training:")
255
+ print(" Add --resume=true to the lerobot-train command")
256
+ print()
257
+ print("="*70)
258
+
259
+
260
+ def main():
261
+ """Main entry point for training setup."""
262
+ # Parse command line arguments
263
+ parser = argparse.ArgumentParser(
264
+ description="Setup Mortis training infrastructure and generate training scripts"
265
+ )
266
+ parser.add_argument(
267
+ "--dataset-repo-id",
268
+ type=str,
269
+ required=True,
270
+ help="Hugging Face dataset repository ID (e.g., username/dataset-name)"
271
+ )
272
+ parser.add_argument(
273
+ "--output-dir",
274
+ type=str,
275
+ default="outputs/train",
276
+ help="Base directory for training outputs/checkpoints (default: outputs/train)"
277
+ )
278
+ parser.add_argument(
279
+ "--scripts-dir",
280
+ type=str,
281
+ default="train",
282
+ help="Directory to save training scripts (default: train)"
283
+ )
284
+ parser.add_argument(
285
+ "--job-name",
286
+ type=str,
287
+ default=None,
288
+ help="Name for the training job (default: derived from dataset name)"
289
+ )
290
+ parser.add_argument(
291
+ "--model-repo-id",
292
+ type=str,
293
+ default=None,
294
+ help="Hugging Face model repository ID for pushing trained model"
295
+ )
296
+ parser.add_argument(
297
+ "--batch-size",
298
+ type=int,
299
+ default=16,
300
+ help="Training batch size (default: 16)"
301
+ )
302
+ parser.add_argument(
303
+ "--steps",
304
+ type=int,
305
+ default=20000,
306
+ help="Total training steps (default: 20000)"
307
+ )
308
+ parser.add_argument(
309
+ "--policy-path",
310
+ type=str,
311
+ default="lerobot/smolvla_base",
312
+ help="Path to base policy (default: lerobot/smolvla_base)"
313
+ )
314
+ parser.add_argument(
315
+ "--no-wandb",
316
+ action="store_true",
317
+ help="Disable Weights & Biases logging"
318
+ )
319
+ parser.add_argument(
320
+ "--generate-configs",
321
+ action="store_true",
322
+ help="Generate multiple training configurations (quick, standard, full)"
323
+ )
324
+
325
+ args = parser.parse_args()
326
+
327
+ # Load environment variables
328
+ REPO_ROOT = Path(__file__).resolve().parents[2]
329
+ load_dotenv(REPO_ROOT / ".env")
330
+
331
+ print("="*70)
332
+ print("Mortis Training Setup")
333
+ print("="*70)
334
+ print()
335
+
336
+ # Derive job name from dataset if not provided
337
+ job_name = args.job_name
338
+ if not job_name:
339
+ # Extract dataset name from repo_id
340
+ dataset_name = args.dataset_repo_id.split('/')[-1]
341
+ job_name = f"smolvla_{dataset_name}"
342
+ print(f"Using job name: {job_name}")
343
+ print()
344
+
345
+ # Create generator
346
+ generator = TrainingScriptGenerator(
347
+ dataset_repo_id=args.dataset_repo_id,
348
+ output_dir=args.output_dir,
349
+ job_name=job_name,
350
+ model_repo_id=args.model_repo_id,
351
+ scripts_dir=args.scripts_dir
352
+ )
353
+
354
+ print()
355
+
356
+ if args.generate_configs:
357
+ # Generate multiple configurations
358
+ print("Generating training configurations...")
359
+ generator.generate_training_configs()
360
+ else:
361
+ # Generate single training script
362
+ print("Generating training script...")
363
+ generator.generate_training_script(
364
+ script_name="train.sh",
365
+ policy_path=args.policy_path,
366
+ batch_size=args.batch_size,
367
+ steps=args.steps,
368
+ enable_wandb=not args.no_wandb
369
+ )
370
+
371
+ # Print usage instructions
372
+ generator.print_usage_instructions()
373
+
374
+ # Final tips
375
+ print("\n💡 Tips:")
376
+ print(" - Adjust batch_size based on your GPU memory")
377
+ print(" - Monitor GPU usage with: watch -n 1 nvidia-smi")
378
+ print(" - Training logs are saved in the output directory")
379
+ print(" - Use Ctrl+C to stop training (checkpoints are saved)")
380
+ print()
381
+ print("="*70)
382
+
383
+
384
+ if __name__ == "__main__":
385
+ main()
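The generator can also be driven programmatically instead of through the CLI flags; a minimal sketch (ids and names illustrative):

```python
from mortis.setup_train import TrainingScriptGenerator

gen = TrainingScriptGenerator(
    dataset_repo_id="your-hf-user/mortis_manipulation",  # illustrative
    job_name="smolvla_demo",
)
gen.generate_training_configs()  # writes train_quick.sh / train_standard.sh / train_full.sh
gen.print_usage_instructions()
```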
src/mortis/smolvla_executor.py ADDED
@@ -0,0 +1,1040 @@
1
+ """
2
+ SmolVLA Executor for vision-language-action robotic manipulation.
3
+
4
+ This module implements the SmolVLA model executor that performs inference
5
+ for manipulation tasks using the trained SmolVLA policy from LeRobot.
6
+ """
7
+
8
+ import os
9
+ import time
10
+ import logging
11
+ from pathlib import Path
12
+ from typing import Optional, Dict, Any, Tuple
13
+ from threading import Lock, Event
14
+
15
+ import torch
16
+ import numpy as np
17
+ from PIL import Image as PILImage
18
+
19
+ # LeRobot imports
20
+ from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
21
+ from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
22
+
23
+ # Local imports
24
+ from .robot import MortisArm, HOME_POSE
25
+
26
+
27
+ # Configure logging
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class SmolVLAError(Exception):
32
+ """Base exception for SmolVLA executor errors."""
33
+ pass
34
+
35
+
36
+ class SafetyViolationError(SmolVLAError):
37
+ """Exception raised when a safety constraint is violated."""
38
+ pass
39
+
40
+
41
+ class TimeoutError(SmolVLAError):
42
+ """Exception raised when execution exceeds timeout."""
43
+ pass
44
+
45
+
46
+ class GPUOutOfMemoryError(SmolVLAError):
47
+ """Exception raised when GPU runs out of memory."""
48
+ pass
49
+
50
+
51
+ class SmolVLAExecutor:
52
+ """
53
+ Executor for SmolVLA vision-language-action model inference.
54
+
55
+ This class handles loading the trained SmolVLA model, capturing observations
56
+ from the robot and camera, running inference, and executing predicted actions
57
+ on the SO101 robotic arm.
58
+
59
+ Attributes:
60
+ checkpoint_path: Path to the trained model checkpoint
61
+ device: Device to run inference on ('cuda' or 'cpu')
62
+ policy: Loaded SmolVLA policy model
63
+ robot_arm: Reference to MortisArm instance for action execution
64
+ camera: Camera interface for visual observations (to be implemented)
65
+ valid_commands: List of trained manipulation task commands
66
+ """
67
+
68
+ # Valid manipulation commands that the model was trained on
69
+ VALID_COMMANDS = [
70
+ "Pick up the skull and place it in the green cup",
71
+ "Pick up the skull and place it in the orange cup",
72
+ "Pick up the skull and place it in the purple cup",
73
+ "Pick up the eyeball and place it in the green cup",
74
+ "Pick up the eyeball and place it in the orange cup",
75
+ "Pick up the eyeball and place it in the purple cup",
76
+ ]
77
+
78
+ # Safety limits for joint positions (in degrees)
79
+ # These define the safe workspace boundaries
80
+ # Based on SO101 calibration and physical constraints
81
+ JOINT_LIMITS = {
82
+ "shoulder_pan.pos": (-180, 180),
83
+ "shoulder_lift.pos": (-120, 120), # Extended range for SO101
84
+ "elbow_flex.pos": (-135, 135),
85
+ "wrist_flex.pos": (-105, 105), # Extended range for SO101
86
+ "wrist_roll.pos": (-180, 180),
87
+ "gripper.pos": (0, 100), # 0=open, 100=closed
88
+ }
89
+
90
+ # Maximum allowed joint velocity (degrees per step)
91
+ MAX_JOINT_VELOCITY = 10.0
92
+
93
+ # Default execution timeout (seconds)
94
+ DEFAULT_TIMEOUT = 30.0
95
+
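A check against these constants reduces to a per-joint range test plus a bound on the step-to-step delta. A minimal sketch of the idea (an illustrative helper, not the module's own `_check_action_safety`):

```python
from typing import Dict, Optional, Tuple

def violates_limits(
    action: Dict[str, float],
    prev: Optional[Dict[str, float]],
    limits: Dict[str, Tuple[float, float]],
    max_vel: float = 10.0,
) -> bool:
    """Return True if any commanded joint leaves its range or moves too fast."""
    for name, (lo, hi) in limits.items():
        pos = action.get(name)
        if pos is None:
            continue  # joint not commanded this step
        if not lo <= pos <= hi:
            return True
        if prev is not None and abs(pos - prev.get(name, pos)) > max_vel:
            return True
    return False
```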
96
+ def __init__(
97
+ self,
98
+ checkpoint_path: str,
99
+ robot_arm: Optional[MortisArm] = None,
100
+ device: Optional[str] = None,
101
+ enable_safety_checks: bool = True,
102
+ timeout: Optional[float] = None
103
+ ):
104
+ """
105
+ Initialize the SmolVLA executor.
106
+
107
+ Args:
108
+ checkpoint_path: Path to the trained SmolVLA model checkpoint
109
+ robot_arm: Optional MortisArm instance (will create if not provided)
110
+ device: Device to run inference on ('cuda', 'cpu', or None for auto-detect)
111
+ enable_safety_checks: Whether to enable workspace safety checks
112
+ timeout: Execution timeout in seconds (None for default)
113
+
114
+ Raises:
115
+ SmolVLAError: If checkpoint path doesn't exist or model loading fails
116
+ """
117
+ # Initialize attributes first (for cleanup in case of early failure)
118
+ self.camera = None
119
+ self.policy = None
120
+ self.preprocessor = None
121
+ self.postprocessor = None
122
+
123
+ self.checkpoint_path = Path(checkpoint_path)
124
+
125
+ # Validate checkpoint path
126
+ if not self.checkpoint_path.exists():
127
+ raise SmolVLAError(f"Checkpoint path does not exist: {checkpoint_path}")
128
+
129
+ # Set device
130
+ if device is None:
131
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
132
+ else:
133
+ self.device = device
134
+
135
+ logger.info(f"Initializing SmolVLA executor on device: {self.device}")
136
+
137
+ # Safety configuration
138
+ self.enable_safety_checks = enable_safety_checks
139
+ self.timeout = timeout if timeout is not None else self.DEFAULT_TIMEOUT
140
+
141
+ # Emergency stop flag and lock
142
+ self._emergency_stop_flag = Event()
143
+ self._execution_lock = Lock()
144
+ self._is_executing = False
145
+
146
+ # Previous state for velocity checking
147
+ self._previous_state = None
148
+
149
+ logger.info(f"Safety checks: {'enabled' if enable_safety_checks else 'disabled'}")
150
+ logger.info(f"Execution timeout: {self.timeout}s")
151
+
152
+ # Initialize robot arm
153
+ self.robot_arm = robot_arm
154
+ if self.robot_arm is None:
155
+ logger.info("No robot arm provided, creating new MortisArm instance")
156
+ self.robot_arm = MortisArm()
157
+
158
+ # Load the model
159
+ self._load_model()
160
+
161
+ # Model is ready
162
+ logger.info("SmolVLA executor initialized successfully")
163
+
164
+ def _load_model(self):
165
+ """
166
+ Load the SmolVLA model from checkpoint.
167
+
168
+ Raises:
169
+ SmolVLAError: If model loading fails
170
+ """
171
+ try:
172
+ logger.info(f"Loading SmolVLA model from: {self.checkpoint_path}")
173
+
174
+ # Load config.json directly so the 'type' field can be fixed up before loading the policy
+ import json
+ config_path = self.checkpoint_path / "config.json"
180
+
181
+ if config_path.exists():
182
+ # Load config
183
+ with open(config_path, 'r') as f:
184
+ config_dict = json.load(f)
185
+
186
+ # Ensure 'type' field is set to 'smolvla'
187
+ if 'type' not in config_dict or config_dict['type'] != 'smolvla':
188
+ logger.debug("Setting 'type' field to 'smolvla' in config")
189
+ config_dict['type'] = 'smolvla'
190
+
191
+ # Save updated config back
192
+ with open(config_path, 'w') as f:
193
+ json.dump(config_dict, f, indent=2)
194
+
195
+ # Get VLM model name for tokenizer
196
+ vlm_model_name = config_dict.get('vlm_model_name', 'HuggingFaceTB/SmolVLM2-500M-Video-Instruct')
197
+ else:
198
+ vlm_model_name = 'HuggingFaceTB/SmolVLM2-500M-Video-Instruct'
199
+
200
+ # Load policy using from_pretrained (it will load the config automatically)
201
+ self.policy = SmolVLAPolicy.from_pretrained(str(self.checkpoint_path))
202
+
203
+ # Move to device
204
+ self.policy.to(self.device)
205
+
206
+ # Set to evaluation mode
207
+ self.policy.eval()
208
+
209
+ logger.info("SmolVLA model loaded successfully")
210
+
211
+ # Load preprocessor (handles tokenization automatically)
212
+ self._load_preprocessor()
213
+
214
+ # Perform warmup inference
215
+ self._warmup()
216
+
217
+ except Exception as e:
218
+ logger.error(f"Failed to load SmolVLA model: {e}")
219
+ raise SmolVLAError(f"Model loading failed: {e}")
220
+
221
+ def _load_preprocessor(self):
222
+ """
223
+ Load preprocessor from checkpoint.
224
+
225
+ The preprocessor handles automatic tokenization of task strings
226
+ through the TokenizerProcessorStep.
227
+
228
+ Raises:
229
+ SmolVLAError: If preprocessor loading fails
230
+ """
231
+ try:
232
+ from lerobot.policies.factory import make_pre_post_processors
233
+
234
+ logger.info("Loading preprocessor from checkpoint...")
235
+
236
+ # Load preprocessor and postprocessor using policy config
237
+ self.preprocessor, self.postprocessor = make_pre_post_processors(
238
+ self.policy.config,
239
+ pretrained_path=str(self.checkpoint_path),
240
+ device=self.device
241
+ )
242
+
243
+ logger.info("Preprocessor and postprocessor loaded successfully")
244
+
245
+ except Exception as e:
246
+ logger.error(f"Failed to load preprocessor: {e}")
247
+ raise SmolVLAError(f"Preprocessor loading failed: {e}")
248
+
249
+ def _warmup(self):
250
+ """
251
+ Perform warmup inference to initialize CUDA kernels and caches.
252
+
253
+ This reduces latency for the first real inference call.
254
+ """
255
+ if self.device == "cuda":
256
+ logger.info("Performing model warmup...")
257
+ try:
258
+ # Create dummy observation
259
+ dummy_obs = self._create_dummy_observation()
260
+
261
+ # Run dummy inference
262
+ with torch.no_grad():
263
+ # SmolVLA expects a batch of observations
264
+ result = self.policy.select_action(dummy_obs)
265
+ # Result may be a dict with 'action' key or just a tensor
266
+ if isinstance(result, dict):
267
+ _ = result.get('action', result)
268
+
269
+ # Clear cache
270
+ torch.cuda.empty_cache()
271
+
272
+ logger.info("Model warmup complete")
273
+ except Exception as e:
274
+ # Warmup is optional - log but don't fail
275
+ logger.debug(f"Warmup skipped: {e}")
276
+ pass
277
+
278
+ def _create_dummy_observation(self) -> Dict[str, torch.Tensor]:
279
+ """
280
+ Create a dummy observation for warmup.
281
+
282
+ Returns:
283
+ Dictionary with dummy observation tensors
284
+ """
285
+ # Create dummy state
286
+ dummy_state = torch.zeros(1, 6, dtype=torch.float32, device=self.device)
287
+
288
+ # Create dummy images
289
+ dummy_image = self._create_dummy_image()
290
+
291
+ observation = {
292
+ "observation.images.camera1": dummy_image,
293
+ "observation.images.camera2": dummy_image.clone(),
294
+ "observation.images.camera3": dummy_image.clone(),
295
+ "observation.state": dummy_state,
296
+ "task": "dummy task" # Task as string (preprocessor will handle it)
297
+ }
298
+
299
+ # Apply preprocessor to tokenize task
300
+ if self.preprocessor is not None:
301
+ observation = self.preprocessor(observation)
302
+
303
+ return observation
304
+
305
+ def validate_command(self, command: str) -> bool:
306
+ """
307
+ Validate that a command is in the trained task set.
308
+
309
+ Args:
310
+ command: The manipulation command to validate
311
+
312
+ Returns:
313
+ True if command is valid, False otherwise
314
+ """
315
+ return command in self.VALID_COMMANDS
316
+
317
+ def trigger_emergency_stop(self):
318
+ """
319
+ Trigger emergency stop from external thread.
320
+
321
+ This can be called from another thread to safely stop execution.
322
+ """
323
+ logger.warning("Emergency stop triggered externally")
324
+ self._emergency_stop_flag.set()
325
+
326
+ def is_executing(self) -> bool:
327
+ """
328
+ Check if executor is currently running a task.
329
+
330
+ Returns:
331
+ True if a task is being executed
332
+ """
333
+ return self._is_executing
334
+
335
+ def execute(self, command: str, max_steps: int = 500, timeout: Optional[float] = None) -> bool:
336
+ """
337
+ Execute a manipulation task using SmolVLA inference.
338
+
339
+ This is the main entry point for executing manipulation commands.
340
+ It runs the inference loop, capturing observations and executing
341
+ predicted actions until the task is complete or max_steps is reached.
342
+
343
+ Args:
344
+ command: Natural language task description (must be in VALID_COMMANDS)
345
+ max_steps: Maximum number of inference steps to execute
346
+ timeout: Optional timeout override (seconds)
347
+
348
+ Returns:
349
+ True if execution completed successfully, False otherwise
350
+
351
+ Raises:
352
+ SmolVLAError: If command is invalid or execution fails critically
353
+ SafetyViolationError: If safety constraints are violated
354
+ TimeoutError: If execution exceeds timeout
355
+ """
356
+ # Acquire execution lock to prevent concurrent execution
357
+ if not self._execution_lock.acquire(blocking=False):
358
+ raise SmolVLAError("Executor is already running a task")
359
+
360
+ try:
361
+ # Clear emergency stop flag
362
+ self._emergency_stop_flag.clear()
363
+ self._is_executing = True
364
+
365
+ # Validate command against trained task set
366
+ if not self.validate_command(command):
367
+ raise SmolVLAError(
368
+ f"Invalid command: '{command}'. "
369
+ f"Must be one of: {self.VALID_COMMANDS}"
370
+ )
371
+
372
+ # Ensure robot is connected
373
+ if not self.robot_arm.connected:
374
+ logger.info("Robot not connected, attempting to connect...")
375
+ self.robot_arm.connect()
376
+ if not self.robot_arm.connected:
377
+ raise SmolVLAError("Failed to connect to robot arm")
378
+
379
+ # Use provided timeout or default
380
+ execution_timeout = timeout if timeout is not None else self.timeout
381
+
382
+ logger.info(f"Starting SmolVLA execution: '{command}'")
383
+ logger.info(f"Max steps: {max_steps}, Timeout: {execution_timeout}s")
384
+ logger.info(f"Safety checks: {'enabled' if self.enable_safety_checks else 'disabled'}")
385
+
386
+ try:
387
+ # Execute the task with timeout
388
+ success = self._execute_task_with_timeout(command, max_steps, execution_timeout)
389
+
390
+ if success:
391
+ logger.info(f"Task completed successfully: '{command}'")
392
+ else:
393
+ logger.warning(f"Task did not complete within constraints")
394
+
395
+ # Return to home position safely
396
+ logger.info("Returning to home position...")
397
+ self._safe_return_home()
398
+
399
+ return success
400
+
401
+ except TimeoutError as e:
402
+ logger.error(f"Execution timeout: {e}")
403
+ self._emergency_stop()
404
+ raise
405
+ except SafetyViolationError as e:
406
+ logger.error(f"Safety violation: {e}")
407
+ self._emergency_stop()
408
+ raise
409
+ except GPUOutOfMemoryError as e:
410
+ logger.error(f"GPU out of memory: {e}")
411
+ self._handle_gpu_oom()
412
+ self._emergency_stop()
413
+ raise
414
+ except Exception as e:
415
+ logger.error(f"Execution failed: {e}")
416
+ import traceback
417
+ logger.error(f"Traceback: {traceback.format_exc()}")
418
+ self._emergency_stop()
419
+ raise SmolVLAError(f"Execution failed: {e}")
420
+
421
+ finally:
422
+ # Always release lock and reset execution flag
423
+ self._is_executing = False
424
+ self._execution_lock.release()
425
+
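`execute()` blocks and holds the execution lock for the whole task, so a stop request has to arrive from another thread via `trigger_emergency_stop()`. A sketch (checkpoint path illustrative):

```python
import threading

from mortis.smolvla_executor import SmolVLAExecutor

executor = SmolVLAExecutor(
    "outputs/train/smolvla_mortis/checkpoints/last/pretrained_model"  # illustrative path
)
command = "Pick up the skull and place it in the green cup"

worker = threading.Thread(
    target=lambda: executor.execute(command, max_steps=300),
    daemon=True,
)
worker.start()
# ... later, e.g. from a UI handler or watchdog:
executor.trigger_emergency_stop()  # checked at the top of every inference step
worker.join()
```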
426
+ def _execute_task_with_timeout(self, command: str, max_steps: int, timeout: float) -> bool:
427
+ """
428
+ Execute task with timeout monitoring.
429
+
430
+ Args:
431
+ command: The manipulation command
432
+ max_steps: Maximum steps
433
+ timeout: Timeout in seconds
434
+
435
+ Returns:
436
+ True if task completed successfully
437
+
438
+ Raises:
439
+ TimeoutError: If execution exceeds timeout
440
+ """
441
+ start_time = time.time()
442
+
443
+ try:
444
+ return self._execute_task(command, max_steps, start_time, timeout)
445
+ except Exception as e:
446
+ elapsed = time.time() - start_time
447
+ if elapsed >= timeout:
448
+ raise TimeoutError(f"Execution exceeded timeout of {timeout}s")
449
+ raise
450
+
451
+ def _execute_task(self, command: str, max_steps: int, start_time: float, timeout: float) -> bool:
452
+ """
453
+ Internal method to execute the task inference loop.
454
+
455
+ This method implements the core inference loop:
456
+ 1. Capture visual and state observations
457
+ 2. Run SmolVLA inference to predict next action
458
+ 3. Execute action on robot
459
+ 4. Check for task completion
460
+ 5. Repeat until complete or max_steps reached
461
+
462
+ Args:
463
+ command: The manipulation command to execute
+ max_steps: Maximum number of steps
+ start_time: Timestamp when execution began (from time.time())
+ timeout: Overall execution timeout in seconds
465
+
466
+ Returns:
467
+ True if task completed, False if max steps reached
468
+ """
469
+ # Reset task completion tracking variables
470
+ self._previous_action = None
471
+ self._stable_count = 0
472
+ self._previous_state = None
473
+
474
+ # Track execution metrics
475
+ last_progress_log = 0
476
+ progress_log_interval = 50 # Log every 50 steps
477
+
478
+ with torch.no_grad():
479
+ for step in range(max_steps):
480
+ # Check for emergency stop
481
+ if self._emergency_stop_flag.is_set():
482
+ logger.warning("Emergency stop detected, aborting execution")
483
+ return False
484
+
485
+ # Check timeout
486
+ elapsed = time.time() - start_time
487
+ if elapsed >= timeout:
488
+ raise TimeoutError(f"Execution exceeded timeout of {timeout}s at step {step}")
489
+
490
+ # Log progress periodically
491
+ if step - last_progress_log >= progress_log_interval:
492
+ fps = step / elapsed if elapsed > 0 else 0
493
+ logger.info(
494
+ f"Execution progress: step {step}/{max_steps} "
495
+ f"({step/max_steps*100:.1f}%) - {fps:.1f} FPS - {elapsed:.1f}s elapsed"
496
+ )
497
+ last_progress_log = step
498
+
499
+ try:
500
+ # Capture current observation
501
+ observation = self._get_observation()
502
+
503
+ # Add task string (preprocessor will tokenize it)
504
+ observation = self._add_task_string(observation, command)
505
+
506
+ # Apply preprocessor (tokenizes task string automatically)
507
+ observation = self.preprocessor(observation)
508
+
509
+ # Run inference to predict next action (normalized)
510
+ action_normalized = self._run_inference_with_oom_handling(observation)
511
+
512
+ # Debug: log normalized action
513
+ logger.debug(f"Normalized action type: {type(action_normalized)}, shape: {action_normalized.shape if hasattr(action_normalized, 'shape') else 'N/A'}")
514
+
515
+ # Denormalize action using postprocessor
516
+ action = self.postprocessor(action_normalized)
517
+
518
+ # Debug: log denormalized action
519
+ logger.debug(f"Denormalized action: {action}")
520
+
521
+ # Validate action safety (on denormalized action)
522
+ if self.enable_safety_checks:
523
+ self._check_action_safety(action, observation)
524
+
525
+ # Send action to robot
526
+ self._send_action(action)
527
+
528
+ # Check if task is complete (use normalized action for stability check)
529
+ try:
530
+ is_complete = self._is_task_complete(observation, step, action_normalized)
531
+ if is_complete:
532
+ elapsed = time.time() - start_time
533
+ logger.info(
534
+ f"Task completed at step {step} "
535
+ f"(elapsed: {elapsed:.2f}s, avg FPS: {step/elapsed:.1f})"
536
+ )
537
+ return True
538
+ except Exception as e:
539
+ logger.error(f"Error in _is_task_complete: {e}")
540
+ raise
541
+
542
+ # Small delay between steps to maintain ~30 FPS
543
+ time.sleep(0.033)
544
+
545
+ except torch.cuda.OutOfMemoryError as e:
546
+ logger.error(f"GPU out of memory at step {step}")
547
+ raise GPUOutOfMemoryError(f"GPU OOM at step {step}: {e}")
548
+ except SafetyViolationError:
549
+ # Re-raise safety violations
550
+ raise
551
+ except Exception as e:
552
+ logger.error(f"Error at step {step}: {e}")
553
+ raise
554
+
555
+ # Max steps reached without completion
556
+ elapsed = time.time() - start_time
557
+ logger.warning(
558
+ f"Task did not complete within {max_steps} steps "
559
+ f"(elapsed: {elapsed:.2f}s)"
560
+ )
561
+ return False
562
+
563
+ def _get_observation(self) -> Dict[str, torch.Tensor]:
564
+ """
565
+ Get current robot observation (image + state).
566
+
567
+ Captures robot state from robot.get_observation() and images from cameras.
568
+
569
+ Returns:
570
+ Dictionary with observation tensors formatted for SmolVLA:
571
+ - observation.images.camera1: RGB image tensor [1, 3, H, W]
572
+ - observation.images.camera2: RGB image tensor [1, 3, H, W] (if available)
573
+ - observation.images.camera3: RGB image tensor [1, 3, H, W] (if available)
574
+ - observation.state: Joint positions tensor [1, 6]
575
+ """
576
+ try:
577
+ # Get robot state (joint positions)
578
+ robot_obs = self.robot_arm.robot.get_observation()
579
+
580
+ # Extract joint positions in order
581
+ joint_names = [
582
+ "shoulder_pan.pos",
583
+ "shoulder_lift.pos",
584
+ "elbow_flex.pos",
585
+ "wrist_flex.pos",
586
+ "wrist_roll.pos",
587
+ "gripper.pos"
588
+ ]
589
+
590
+ # Build state vector
591
+ state_values = [robot_obs[name] for name in joint_names]
592
+ state_tensor = torch.tensor(
593
+ state_values,
594
+ dtype=torch.float32,
595
+ device=self.device
596
+ ).unsqueeze(0) # Add batch dimension
597
+
598
+ # Get camera images (robot.cameras is a dict of camera objects)
599
+ observation = {"observation.state": state_tensor}
600
+
601
+ if hasattr(self.robot_arm.robot, 'cameras') and self.robot_arm.robot.cameras:
602
+ # Get images from robot's cameras
603
+ for i, (camera_name, camera) in enumerate(self.robot_arm.robot.cameras.items(), start=1):
604
+ try:
605
+ image = camera.read()
606
+ # Convert to tensor: (H, W, C) -> (1, C, H, W)
607
+ image_tensor = torch.from_numpy(image).float().permute(2, 0, 1).unsqueeze(0).to(self.device) / 255.0
608
+ observation[f"observation.images.camera{i}"] = image_tensor
609
+ logger.debug(f"Captured image from {camera_name}: shape={image.shape}")
610
+ except Exception as e:
611
+ logger.warning(f"Failed to read from {camera_name}: {e}")
612
+ observation[f"observation.images.camera{i}"] = self._create_dummy_image()
613
+ else:
614
+ logger.debug("No cameras configured on robot, using dummy images")
615
+
616
+ # Ensure we have 3 camera images (duplicate if needed)
617
+ for i in range(1, 4):
618
+ key = f"observation.images.camera{i}"
619
+ if key not in observation:
620
+ # Use first camera or dummy
621
+ if "observation.images.camera1" in observation:
622
+ observation[key] = observation["observation.images.camera1"].clone()
623
+ else:
624
+ observation[key] = self._create_dummy_image()
625
+
626
+ logger.debug(f"Captured observation with keys: {list(observation.keys())}")
627
+ return observation
628
+
629
+ except Exception as e:
630
+ logger.warning(f"Failed to get robot observation: {e}. Using dummy observation.")
631
+ return self._create_dummy_observation_without_task()
632
+
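The per-frame conversion above turns an `(H, W, C)` uint8 camera frame into the normalized `[1, C, H, W]` float tensor the policy expects; a standalone check of that transform:

```python
import numpy as np
import torch

frame = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)  # stand-in frame
tensor = torch.from_numpy(frame).float().permute(2, 0, 1).unsqueeze(0) / 255.0

assert tensor.shape == (1, 3, 480, 640)
assert 0.0 <= float(tensor.min()) and float(tensor.max()) <= 1.0
```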
633
+ def _create_dummy_observation_without_task(self) -> Dict[str, torch.Tensor]:
634
+ """
635
+ Create a dummy observation without task string (for error recovery).
636
+
637
+ Returns:
638
+ Dictionary with dummy observation tensors
639
+ """
640
+ dummy_state = torch.zeros(1, 6, dtype=torch.float32, device=self.device)
641
+ dummy_image = self._create_dummy_image()
642
+
643
+ return {
644
+ "observation.images.camera1": dummy_image,
645
+ "observation.images.camera2": dummy_image.clone(),
646
+ "observation.images.camera3": dummy_image.clone(),
647
+ "observation.state": dummy_state,
648
+ }
649
+
650
+ def _add_task_string(self, observation: Dict[str, torch.Tensor], command: str) -> Dict[str, torch.Tensor]:
651
+ """
652
+ Add task string to observation.
653
+
654
+ The preprocessor will automatically tokenize this string through
655
+ the TokenizerProcessorStep.
656
+
657
+ Args:
658
+ observation: Current observation dictionary
659
+ command: Natural language command string
660
+
661
+ Returns:
662
+ Observation dictionary with added task string
663
+ """
664
+ # Simply add the task string - the preprocessor will tokenize it
665
+ observation["task"] = command
666
+
667
+ logger.debug(f"Added task string: '{command}'")
668
+
669
+ return observation
670
+
671
+ def _create_dummy_image(self) -> torch.Tensor:
672
+ """
673
+ Create a dummy image tensor for testing without camera.
674
+
675
+ Returns:
676
+ Dummy image tensor [1, 3, 256, 256] with batch dimension
677
+ """
678
+ # Create black image
679
+ dummy_image = torch.zeros(1, 3, 256, 256, dtype=torch.float32, device=self.device)
680
+ return dummy_image
681
+
682
+ def _send_action(self, action: torch.Tensor):
683
+ """
684
+ Send predicted action to robot.
685
+
686
+ Converts the action tensor from SmolVLA to SO101 command format
687
+ and sends it to the robot arm for execution.
688
+
689
+ Args:
690
+ action: Action tensor from policy (shape: [batch, action_dim])
691
+
692
+ Raises:
693
+ SmolVLAError: If action execution fails
694
+ """
695
+ try:
696
+ # Convert action tensor to robot command dictionary
697
+ action_dict = self._action_to_dict(action)
698
+
699
+ # Send to robot
700
+ self.robot_arm.robot.send_action(action_dict)
701
+
702
+ # Log action at debug level (verbose)
703
+ logger.debug(f"Action sent: {action_dict}")
704
+
705
+ except Exception as e:
706
+ logger.error(f"Failed to send action to robot: {e}")
707
+ raise SmolVLAError(f"Action execution failed: {e}")
708
+
709
+ def _action_to_dict(self, action: torch.Tensor) -> Dict[str, float]:
710
+ """
711
+ Convert action tensor to SO101 command format.
712
+
713
+ Maps the action tensor dimensions to SO101 joint names and converts
714
+ to the dictionary format expected by the robot driver.
715
+
716
+ Args:
717
+ action: Action tensor from policy (shape: [batch, 6] or [6])
718
+
719
+ Returns:
720
+ Dictionary mapping joint names to positions (in degrees or normalized units)
721
+
722
+ Raises:
723
+ SmolVLAError: If action tensor has invalid shape
724
+ """
725
+ # Remove batch dimension if present
726
+ if action.dim() > 1:
727
+ action = action.squeeze(0)
728
+
729
+ # Validate action dimension
730
+ if action.shape[0] != 6:
731
+ raise SmolVLAError(
732
+ f"Invalid action shape: expected 6 dimensions, got {action.shape[0]}"
733
+ )
734
+
735
+ # Convert to numpy
736
+ action_np = action.cpu().numpy()
737
+
738
+ # Map action dimensions to joint names
739
+ # Order must match the training data format
740
+ joint_names = [
741
+ "shoulder_pan.pos",
742
+ "shoulder_lift.pos",
743
+ "elbow_flex.pos",
744
+ "wrist_flex.pos",
745
+ "wrist_roll.pos",
746
+ "gripper.pos"
747
+ ]
748
+
749
+ # Create action dictionary
750
+ action_dict = {
751
+ name: float(action_np[i])
752
+ for i, name in enumerate(joint_names)
753
+ }
754
+
755
+ return action_dict
756
+
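To make the mapping concrete, a 6-dimensional action tensor converts to the SO101 command dict as follows (a sketch with made-up joint values; `executor` stands for any SmolVLAExecutor instance):

    import torch

    action = torch.tensor([[12.5, -30.0, 45.0, 0.0, 90.0, 20.0]])  # shape [1, 6]
    cmd = executor._action_to_dict(action)
    # cmd == {"shoulder_pan.pos": 12.5, "shoulder_lift.pos": -30.0,
    #         "elbow_flex.pos": 45.0, "wrist_flex.pos": 0.0,
    #         "wrist_roll.pos": 90.0, "gripper.pos": 20.0}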
757
+ def _is_task_complete(
758
+ self,
759
+ observation: Dict[str, torch.Tensor],
760
+ step: int,
761
+ action: torch.Tensor
762
+ ) -> bool:
763
+ """
764
+ Determine if the task is complete.
765
+
766
+ This method uses multiple heuristics to detect task completion:
767
+ 1. Minimum step count (ensure task has progressed)
768
+ 2. Maximum step count (assume completion after sufficient time)
769
+ 3. Action stability (detect when robot has settled)
770
+
771
+ In a production system, this could be enhanced with:
772
+ - Learned termination classifier
773
+ - Visual goal detection
774
+ - Force/torque feedback
775
+ - Success detection from camera
776
+
777
+ Args:
778
+ observation: Current observation dictionary
779
+ step: Current step number
780
+ action: Predicted action tensor
781
+
782
+ Returns:
783
+ True if task should be considered complete
784
+ """
785
+ # Minimum steps before considering completion (allow task to progress)
786
+ MIN_STEPS = 100
787
+
788
+ # Maximum steps - assume task is complete after this many steps
789
+ # Most manipulation tasks should complete within 400-450 steps at 30 FPS
790
+ # (approximately 13-15 seconds)
791
+ MAX_STEPS = 450
792
+
793
+ # Early exit: not enough steps yet
794
+ if step < MIN_STEPS:
795
+ return False
796
+
797
+ # Late exit: max steps reached, consider complete
798
+ if step >= MAX_STEPS:
799
+ logger.info(f"Task completion: max steps ({MAX_STEPS}) reached")
800
+ return True
801
+
802
+ # Check for action stability (robot has settled into final position)
803
+ if hasattr(self, '_previous_action') and self._previous_action is not None:
804
+ action_diff = torch.abs(action - self._previous_action).max().item()
805
+
806
+ # If action changes are very small, robot may have settled
807
+ if action_diff < 0.01: # Threshold for "stable" action
808
+ if not hasattr(self, '_stable_count'):
809
+ self._stable_count = 0
810
+ self._stable_count += 1
811
+
812
+ # If stable for 30 consecutive steps (~1 second), consider complete
813
+ if self._stable_count >= 30:
814
+ logger.info(
815
+ f"Task completion: action stability detected at step {step} "
816
+ f"(stable for {self._stable_count} steps)"
817
+ )
818
+ return True
819
+ else:
820
+ # Reset stability counter if action changes significantly
821
+ self._stable_count = 0
822
+
823
+ # Store current action for next comparison
824
+ self._previous_action = action.clone()
825
+
826
+ # Not complete yet
827
+ return False
828
+
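The stability heuristic can be exercised in isolation. A self-contained sketch using the same thresholds as the method above (the synthetic stream of identical actions is an assumption):

    import torch

    STABLE_THRESHOLD = 0.01  # max per-joint change considered "settled"
    STABLE_STEPS = 30        # roughly 1 second at 30 FPS

    previous, stable_count = None, 0
    for step, action in enumerate([torch.zeros(6)] * 200):
        if previous is not None:
            if torch.abs(action - previous).max().item() < STABLE_THRESHOLD:
                stable_count += 1
            else:
                stable_count = 0
        previous = action.clone()
        if stable_count >= STABLE_STEPS:
            print(f"Settled at step {step}")  # fires at step 30 here
            break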
829
+ def _check_action_safety(self, action: torch.Tensor, observation: Dict[str, torch.Tensor]):
830
+ """
831
+ Check if predicted action is safe to execute.
832
+
833
+ Validates:
834
+ 1. Joint position limits
835
+ 2. Joint velocity limits
836
+ 3. Workspace boundaries
837
+
838
+ Args:
839
+ action: Predicted action tensor
840
+ observation: Current observation
841
+
842
+ Raises:
843
+ SafetyViolationError: If action violates safety constraints
844
+ """
845
+ # Convert action to dict for checking
846
+ action_dict = self._action_to_dict(action)
847
+
848
+ # Check joint position limits
849
+ for joint_name, position in action_dict.items():
850
+ if joint_name in self.JOINT_LIMITS:
851
+ min_pos, max_pos = self.JOINT_LIMITS[joint_name]
852
+ if position < min_pos or position > max_pos:
853
+ raise SafetyViolationError(
854
+ f"Joint {joint_name} position {position:.2f} exceeds limits "
855
+ f"[{min_pos}, {max_pos}]"
856
+ )
857
+
858
+ # Check joint velocity limits (per-step position change, used as a velocity proxy)
859
+ if self._previous_state is not None:
860
+ current_state = observation["observation.state"].squeeze(0).cpu().numpy()
861
+ velocity = np.abs(current_state - self._previous_state)
862
+ max_velocity = np.max(velocity)
863
+
864
+ if max_velocity > self.MAX_JOINT_VELOCITY:
865
+ raise SafetyViolationError(
866
+ f"Joint velocity {max_velocity:.2f} exceeds limit "
867
+ f"{self.MAX_JOINT_VELOCITY}"
868
+ )
869
+
870
+ # Update previous state for next check
871
+ self._previous_state = observation["observation.state"].squeeze(0).cpu().numpy().copy()
872
+
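The position check above reduces to a plain range test. A minimal sketch with placeholder limits (the real `JOINT_LIMITS` table is defined on the executor and may differ):

    JOINT_LIMITS = {"shoulder_pan.pos": (-110.0, 110.0)}  # placeholder range

    def check_limits(action_dict: dict) -> None:
        for name, pos in action_dict.items():
            lo, hi = JOINT_LIMITS.get(name, (float("-inf"), float("inf")))
            if not lo <= pos <= hi:
                raise ValueError(f"{name}={pos:.2f} outside [{lo}, {hi}]")

    check_limits({"shoulder_pan.pos": 45.0})  # within range, no exception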
873
+ def _run_inference_with_oom_handling(self, observation: Dict[str, torch.Tensor]) -> torch.Tensor:
874
+ """
875
+ Run inference with GPU out-of-memory handling.
876
+
877
+ Args:
878
+ observation: Current observation
879
+
880
+ Returns:
881
+ Predicted action tensor
882
+
883
+ Raises:
884
+ GPUOutOfMemoryError: If GPU runs out of memory
885
+ """
886
+ try:
887
+ result = self.policy.select_action(observation)
888
+
889
+ # Debug: log what we got back
890
+ logger.debug(f"Policy returned type: {type(result)}")
891
+ if isinstance(result, dict):
892
+ logger.debug(f"Policy returned dict keys: {result.keys()}")
893
+
894
+ # SmolVLA returns a dictionary with 'action' key
895
+ if isinstance(result, dict):
896
+ if 'action' in result:
897
+ return result['action']
898
+ else:
899
+ # Try to find the action in the dict
900
+ logger.error(f"Policy returned dict without 'action' key. Keys: {result.keys()}")
901
+ raise SmolVLAError(f"Policy returned unexpected format: {type(result)}")
902
+ return result
903
+ except torch.cuda.OutOfMemoryError as e:
904
+ logger.error("GPU out of memory during inference")
905
+ # Try to recover by clearing cache
906
+ torch.cuda.empty_cache()
907
+ # Try one more time
908
+ try:
909
+ result = self.policy.select_action(observation)
910
+ if isinstance(result, dict):
911
+ if 'action' in result:
912
+ return result['action']
913
+ else:
914
+ raise SmolVLAError(f"Policy returned unexpected format: {type(result)}")
915
+ return result
916
+ except torch.cuda.OutOfMemoryError:
917
+ raise GPUOutOfMemoryError("GPU out of memory, cannot recover")
918
+
919
+ def _handle_gpu_oom(self):
920
+ """
921
+ Handle GPU out-of-memory error by clearing cache and resetting state.
922
+ """
923
+ logger.info("Handling GPU out-of-memory error...")
924
+
925
+ if self.device == "cuda":
926
+ # Clear CUDA cache
927
+ torch.cuda.empty_cache()
928
+
929
+ # Log memory stats
930
+ if torch.cuda.is_available():
931
+ allocated = torch.cuda.memory_allocated() / 1024**3
932
+ reserved = torch.cuda.memory_reserved() / 1024**3
933
+ logger.info(f"GPU memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
934
+
935
+ logger.info("GPU memory cleared")
936
+
937
+ def _safe_return_home(self):
938
+ """
939
+ Safely return robot to home position with error handling.
940
+ """
941
+ try:
942
+ self.robot_arm.move_arm("idle")
943
+ logger.info("Robot returned to home position")
944
+ except Exception as e:
945
+ logger.error(f"Failed to return to home position: {e}")
946
+ # Try direct position command as fallback
947
+ try:
948
+ self.robot_arm.robot.send_action(HOME_POSE)
949
+ logger.info("Robot returned to home using direct command")
950
+ except Exception as e2:
951
+ logger.error(f"Direct home command also failed: {e2}")
952
+
953
+ def _emergency_stop(self):
954
+ """
955
+ Emergency stop: return robot to safe idle position.
956
+
957
+ This is called when an error occurs during execution.
958
+ Sets the emergency stop flag and attempts to safely stop the robot.
959
+ """
960
+ logger.warning("Emergency stop triggered")
961
+
962
+ # Set emergency stop flag
963
+ self._emergency_stop_flag.set()
964
+
965
+ try:
966
+ # Try to stop robot immediately
967
+ self._safe_return_home()
968
+ logger.info("Emergency stop completed - robot in safe position")
969
+ except Exception as e:
970
+ logger.error(f"Emergency stop failed: {e}")
971
+ logger.error("MANUAL INTERVENTION MAY BE REQUIRED")
972
+
973
+ def cleanup(self):
974
+ """
975
+ Clean up resources (camera, GPU memory, etc.).
976
+
977
+ Should be called when the executor is no longer needed.
978
+ """
979
+ logger.info("Cleaning up SmolVLA executor...")
980
+
981
+ # Disconnect camera
982
+ if hasattr(self, 'camera') and self.camera is not None:
983
+ try:
984
+ self.camera.disconnect()
985
+ except Exception as e:
986
+ logger.warning(f"Camera disconnect failed: {e}")
987
+
988
+ # Clear GPU memory
989
+ if hasattr(self, 'device') and self.device == "cuda":
990
+ torch.cuda.empty_cache()
991
+
992
+ logger.info("Cleanup complete")
993
+
994
+ def __del__(self):
995
+ """Destructor to ensure cleanup."""
996
+ try:
997
+ self.cleanup()
998
+ except Exception:
999
+ # Silently ignore cleanup errors in destructor
1000
+ pass
1001
+
1002
+
1003
+ def init_smolvla_executor(
1004
+ checkpoint_path: Optional[str] = None,
1005
+ robot_arm: Optional[MortisArm] = None,
1006
+ device: Optional[str] = None
1007
+ ) -> SmolVLAExecutor:
1008
+ """
1009
+ Factory function to initialize SmolVLA executor with environment configuration.
1010
+
1011
+ Args:
1012
+ checkpoint_path: Path to model checkpoint (uses env var if not provided)
1013
+ robot_arm: Optional MortisArm instance
1014
+ device: Device to use (uses env var or auto-detect if not provided)
1015
+
1016
+ Returns:
1017
+ Initialized SmolVLAExecutor instance
1018
+
1019
+ Raises:
1020
+ SmolVLAError: If initialization fails
1021
+ """
1022
+ # Get checkpoint path from environment if not provided
1023
+ if checkpoint_path is None:
1024
+ checkpoint_path = os.getenv("SMOLVLA_CHECKPOINT_PATH")
1025
+ if checkpoint_path is None:
1026
+ raise SmolVLAError(
1027
+ "No checkpoint path provided and SMOLVLA_CHECKPOINT_PATH not set"
1028
+ )
1029
+
1030
+ # Get device from environment if not provided
1031
+ if device is None:
1032
+ device = os.getenv("SMOLVLA_DEVICE")
1033
+
1034
+ logger.info(f"Initializing SmolVLA executor with checkpoint: {checkpoint_path}")
1035
+
1036
+ return SmolVLAExecutor(
1037
+ checkpoint_path=checkpoint_path,
1038
+ robot_arm=robot_arm,
1039
+ device=device
1040
+ )
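Typical use of the factory above (a sketch; the checkpoint path and the command string are placeholders):

    import os

    os.environ["SMOLVLA_CHECKPOINT_PATH"] = "/path/to/checkpoint"  # placeholder
    executor = init_smolvla_executor()  # device resolved from env or auto-detected
    try:
        completed = executor.execute("move the skull to the green cup")
    finally:
        executor.cleanup()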
src/mortis/stt_service.py ADDED
@@ -0,0 +1,383 @@
1
+ """
2
+ Speech-to-Text service for Mortis voice input.
3
+
4
+ This module provides the STTService class for converting audio input to text,
5
+ with support for Gemini native audio processing and fallback to Google Cloud Speech-to-Text.
6
+ """
7
+
8
+ import os
9
+ import logging
10
+ from pathlib import Path
11
+ from typing import Optional, Literal
12
+ from enum import Enum
13
+ from dotenv import load_dotenv
14
+
15
+ from google import genai
16
+ from google.genai import types
17
+
18
+ # Load environment variables
19
+ REPO_ROOT = Path(__file__).resolve().parents[2]
20
+ load_dotenv(REPO_ROOT / ".env")
21
+
22
+ # Configure logging
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ class STTProvider(Enum):
27
+ """Available Speech-to-Text providers."""
28
+ GEMINI = "gemini"
29
+ GOOGLE_STT = "google_stt"
30
+
31
+
32
+ class AudioFormat(Enum):
33
+ """Supported audio formats."""
34
+ WAV = "wav"
35
+ MP3 = "mp3"
36
+ WEBM = "webm"
37
+ OGG = "ogg"
38
+ FLAC = "flac"
39
+
40
+
41
+ class AudioProcessingError(Exception):
42
+ """Base exception for audio processing errors."""
43
+ pass
44
+
45
+
46
+ class STTService:
47
+ """
48
+ Speech-to-Text service for converting audio input to text.
49
+
50
+ Supports multiple STT providers:
51
+ - Gemini native audio (primary, recommended)
52
+ - Google Cloud Speech-to-Text (fallback)
53
+
54
+ The service automatically handles audio format validation and conversion.
55
+ """
56
+
57
+ def __init__(
58
+ self,
59
+ provider: Optional[STTProvider] = None,
60
+ api_key: Optional[str] = None,
61
+ model_name: Optional[str] = None,
62
+ language_code: str = "en-US",
63
+ enable_fallback: bool = True
64
+ ):
65
+ """
66
+ Initialize STT service.
67
+
68
+ Args:
69
+ provider: STT provider to use (defaults to the STT_PROVIDER env var, or GEMINI)
70
+ api_key: API key for Gemini (defaults to GEMINI_API_KEY env var)
71
+ model_name: Gemini model to use (defaults to GEMINI_MODEL env var or gemini-1.5-flash)
72
+ language_code: Language code for transcription (default: en-US)
73
+ enable_fallback: Whether to enable fallback to Google STT on Gemini failure
74
+ """
75
+ # Determine provider from environment or default to Gemini
76
+ if provider is None:
77
+ provider_str = os.getenv("STT_PROVIDER", "gemini").lower()
78
+ try:
79
+ provider = STTProvider(provider_str)
80
+ except ValueError:
81
+ logger.warning(f"Invalid STT_PROVIDER '{provider_str}', defaulting to GEMINI")
82
+ provider = STTProvider.GEMINI
83
+
84
+ self.provider = provider
85
+ self.language_code = language_code
86
+ self.enable_fallback = enable_fallback
87
+
88
+ # Initialize Gemini client for audio processing
89
+ self.api_key = api_key or os.getenv("GEMINI_API_KEY")
90
+ if not self.api_key:
91
+ raise ValueError("GEMINI_API_KEY must be provided or set in environment")
92
+
93
+ self.model_name = model_name or os.getenv("GEMINI_MODEL", "gemini-1.5-flash")
94
+ self.client = genai.Client(api_key=self.api_key)
95
+
96
+ # Initialize Google Cloud STT client (lazy loading)
97
+ self._google_stt_client = None
98
+
99
+ logger.info(
100
+ f"STTService initialized with provider: {self.provider.value}, "
101
+ f"model: {self.model_name}, language: {self.language_code}, "
102
+ f"fallback: {self.enable_fallback}"
103
+ )
104
+
105
+ def transcribe(self, audio_path: str) -> str:
106
+ """
107
+ Transcribe audio file to text.
108
+
109
+ Args:
110
+ audio_path: Path to audio file
111
+
112
+ Returns:
113
+ Transcribed text
114
+
115
+ Raises:
116
+ AudioProcessingError: If transcription fails with all providers
117
+ FileNotFoundError: If audio file doesn't exist
118
+ """
119
+ # Validate audio file exists
120
+ audio_file = Path(audio_path)
121
+ if not audio_file.exists():
122
+ raise FileNotFoundError(f"Audio file not found: {audio_path}")
123
+
124
+ # Validate audio format
125
+ if not self._validate_audio_format(audio_file):
126
+ raise AudioProcessingError(
127
+ f"Unsupported audio format: {audio_file.suffix}. "
128
+ f"Supported formats: {[fmt.value for fmt in AudioFormat]}"
129
+ )
130
+
131
+ logger.info(f"Transcribing audio file: {audio_path} using {self.provider.value}")
132
+
133
+ # Try primary provider
134
+ try:
135
+ if self.provider == STTProvider.GEMINI:
136
+ return self._transcribe_with_gemini(audio_path)
137
+ elif self.provider == STTProvider.GOOGLE_STT:
138
+ return self._transcribe_with_google_stt(audio_path)
139
+ except Exception as e:
140
+ logger.warning(f"Primary STT provider ({self.provider.value}) failed: {e}")
141
+
142
+ # Try fallback if enabled
143
+ if self.enable_fallback:
144
+ logger.info("Attempting fallback STT provider...")
145
+ try:
146
+ if self.provider == STTProvider.GEMINI:
147
+ # Fallback to Google STT
148
+ return self._transcribe_with_google_stt(audio_path)
149
+ else:
150
+ # Fallback to Gemini
151
+ return self._transcribe_with_gemini(audio_path)
152
+ except Exception as fallback_error:
153
+ logger.error(f"Fallback STT provider also failed: {fallback_error}")
154
+ raise AudioProcessingError(
155
+ f"All STT providers failed. Primary: {e}, Fallback: {fallback_error}"
156
+ ) from fallback_error
157
+ else:
158
+ raise AudioProcessingError(f"STT transcription failed: {e}") from e
159
+
160
+ def _validate_audio_format(self, audio_file: Path) -> bool:
161
+ """
162
+ Validate that audio file format is supported.
163
+
164
+ Args:
165
+ audio_file: Path to audio file
166
+
167
+ Returns:
168
+ True if format is supported, False otherwise
169
+ """
170
+ suffix = audio_file.suffix.lstrip('.').lower()
171
+ supported_formats = [fmt.value for fmt in AudioFormat]
172
+ return suffix in supported_formats
173
+
174
+ def _transcribe_with_gemini(self, audio_path: str) -> str:
175
+ """
176
+ Transcribe audio using Gemini native audio support.
177
+
178
+ Args:
179
+ audio_path: Path to audio file
180
+
181
+ Returns:
182
+ Transcribed text
183
+
184
+ Raises:
185
+ Exception: If Gemini API call fails
186
+ """
187
+ logger.debug(f"Transcribing with Gemini: {audio_path}")
188
+
189
+ try:
190
+ # Upload audio file to Gemini
191
+ audio_file = self.client.files.upload(file=audio_path)
192
+ logger.debug(f"Audio file uploaded: {audio_file.name}")
193
+
194
+ # Create prompt for transcription
195
+ prompt = (
196
+ "Transcribe this audio accurately. "
197
+ "Return only the transcribed text without any additional commentary or formatting."
198
+ )
199
+
200
+ # Generate content with audio
201
+ response = self.client.models.generate_content(
202
+ model=self.model_name,
203
+ contents=[prompt, audio_file]
204
+ )
205
+
206
+ # Extract transcribed text
207
+ if response.text is None:
208
+ logger.warning("Gemini returned None for transcription")
209
+ logger.debug(f"Response object: {response}")
210
+ # Check if there are candidates with parts
211
+ if hasattr(response, 'candidates') and response.candidates:
212
+ logger.debug(f"Response has {len(response.candidates)} candidates")
213
+ for i, candidate in enumerate(response.candidates):
214
+ logger.debug(f"Candidate {i}: {candidate}")
215
+ transcript = ""
216
+ else:
217
+ transcript = response.text.strip()
218
+
219
+ if transcript:
220
+ logger.info(f"Gemini transcription successful: '{transcript[:50]}...'")
221
+ else:
222
+ logger.warning("Gemini transcription returned empty result")
223
+
224
+ # Clean up uploaded file
225
+ try:
226
+ self.client.files.delete(name=audio_file.name)
227
+ logger.debug(f"Deleted uploaded audio file: {audio_file.name}")
228
+ except Exception as cleanup_error:
229
+ logger.warning(f"Failed to delete uploaded audio file: {cleanup_error}")
230
+
231
+ return transcript
232
+
233
+ except Exception as e:
234
+ logger.error(f"Gemini transcription failed: {type(e).__name__}: {e}")
235
+ raise
236
+
237
+ def _transcribe_with_google_stt(self, audio_path: str) -> str:
238
+ """
239
+ Transcribe audio using Google Cloud Speech-to-Text API.
240
+
241
+ Args:
242
+ audio_path: Path to audio file
243
+
244
+ Returns:
245
+ Transcribed text
246
+
247
+ Raises:
248
+ Exception: If Google STT API call fails
249
+ ImportError: If google-cloud-speech is not installed
250
+ """
251
+ logger.debug(f"Transcribing with Google STT: {audio_path}")
252
+
253
+ try:
254
+ from google.cloud import speech_v1
255
+ except ImportError:
256
+ raise ImportError(
257
+ "google-cloud-speech is not installed. "
258
+ "Install it with: pip install google-cloud-speech"
259
+ )
260
+
261
+ # Initialize Google STT client (lazy loading)
262
+ if self._google_stt_client is None:
263
+ self._google_stt_client = speech_v1.SpeechClient()
264
+ logger.debug("Google STT client initialized")
265
+
266
+ # Read audio file
267
+ with open(audio_path, "rb") as audio_file:
268
+ audio_content = audio_file.read()
269
+
270
+ # Determine audio encoding from file extension
271
+ audio_path_obj = Path(audio_path)
272
+ suffix = audio_path_obj.suffix.lstrip('.').lower()
273
+
274
+ encoding_map = {
275
+ "wav": speech_v1.RecognitionConfig.AudioEncoding.LINEAR16,
276
+ "mp3": speech_v1.RecognitionConfig.AudioEncoding.MP3,
277
+ "flac": speech_v1.RecognitionConfig.AudioEncoding.FLAC,
278
+ "ogg": speech_v1.RecognitionConfig.AudioEncoding.OGG_OPUS,
279
+ "webm": speech_v1.RecognitionConfig.AudioEncoding.WEBM_OPUS,
280
+ }
281
+
282
+ encoding = encoding_map.get(suffix, speech_v1.RecognitionConfig.AudioEncoding.LINEAR16)
283
+
284
+ # Configure recognition
285
+ audio = speech_v1.RecognitionAudio(content=audio_content)
286
+ config = speech_v1.RecognitionConfig(
287
+ encoding=encoding,
288
+ language_code=self.language_code,
289
+ enable_automatic_punctuation=True,
290
+ )
291
+
292
+ # Perform transcription
293
+ try:
294
+ response = self._google_stt_client.recognize(config=config, audio=audio)
295
+
296
+ # Extract transcript from results
297
+ if not response.results:
298
+ logger.warning("Google STT returned no results")
299
+ return ""
300
+
301
+ # Combine all alternatives (usually just one)
302
+ transcript = " ".join(
303
+ result.alternatives[0].transcript
304
+ for result in response.results
305
+ if result.alternatives
306
+ )
307
+
308
+ logger.info(f"Google STT transcription successful: '{transcript[:50]}...'")
309
+ return transcript.strip()
310
+
311
+ except Exception as e:
312
+ logger.error(f"Google STT transcription failed: {type(e).__name__}: {e}")
313
+ raise
314
+
315
+ def configure(
316
+ self,
317
+ provider: Optional[STTProvider] = None,
318
+ language_code: Optional[str] = None,
319
+ enable_fallback: Optional[bool] = None
320
+ ):
321
+ """
322
+ Reconfigure STT service settings.
323
+
324
+ Args:
325
+ provider: New STT provider to use
326
+ language_code: New language code
327
+ enable_fallback: Whether to enable fallback
328
+ """
329
+ if provider is not None:
330
+ self.provider = provider
331
+ logger.info(f"STT provider changed to: {provider.value}")
332
+
333
+ if language_code is not None:
334
+ self.language_code = language_code
335
+ logger.info(f"Language code changed to: {language_code}")
336
+
337
+ if enable_fallback is not None:
338
+ self.enable_fallback = enable_fallback
339
+ logger.info(f"Fallback {'enabled' if enable_fallback else 'disabled'}")
340
+
341
+
342
+ # Example usage
343
+ if __name__ == "__main__":
344
+ import sys
345
+
346
+ # Configure logging for testing
347
+ logging.basicConfig(
348
+ level=logging.INFO,
349
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
350
+ )
351
+
352
+ # Check for audio file argument
353
+ if len(sys.argv) < 2:
354
+ print("Usage: python -m mortis.stt_service <audio_file>")
355
+ print("Example: python -m mortis.stt_service test_audio.wav")
356
+ sys.exit(1)
357
+
358
+ audio_file = sys.argv[1]
359
+
360
+ try:
361
+ # Create STT service
362
+ stt_service = STTService()
363
+
364
+ # Transcribe audio
365
+ print(f"\nTranscribing: {audio_file}")
366
+ print("-" * 60)
367
+ transcript = stt_service.transcribe(audio_file)
368
+ print(f"Transcript: {transcript}")
369
+ print("-" * 60)
370
+
371
+ except FileNotFoundError as e:
372
+ print(f"Error: {e}")
373
+ sys.exit(1)
374
+ except AudioProcessingError as e:
375
+ print(f"Audio processing error: {e}")
376
+ sys.exit(1)
377
+ except ValueError as e:
378
+ print(f"Configuration error: {e}")
379
+ print("Please set GEMINI_API_KEY in your .env file")
380
+ sys.exit(1)
381
+ except Exception as e:
382
+ print(f"Unexpected error: {type(e).__name__}: {e}")
383
+ sys.exit(1)
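Selecting the fallback provider explicitly, either through the environment or at runtime (a sketch; it assumes google-cloud-speech is installed with credentials configured, and the audio path is a placeholder):

    # Environment-based selection, read once at construction time:
    #   export STT_PROVIDER=google_stt
    stt = STTService()

    # Or reconfigure a live instance:
    stt.configure(provider=STTProvider.GOOGLE_STT, language_code="en-GB")
    text = stt.transcribe("recording.wav")  # placeholder audio file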
src/mortis/tools.py ADDED
@@ -0,0 +1,418 @@
1
+ """
2
+ LLM integration for Mortis conversational AI.
3
+
4
+ This module provides the ask_mortis() function that integrates with the Gemini API
5
+ to generate character-driven responses and coordinate gesture execution.
6
+ """
7
+
8
+ import logging
9
+ import time
10
+ from typing import Tuple, Optional
11
+ from pathlib import Path
12
+
13
+ from .robot import MortisArm
14
+ from .gemini_client import GeminiClient
15
+ from .models import GeminiResponse
16
+
17
+ # Configure logging
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # Global instances
21
+ mortis_arm = MortisArm()
22
+ gemini_client = None # Lazy initialization
23
+ stt_service = None # Lazy initialization
24
+ tts_service = None # Lazy initialization
25
+ intent_router = None # Lazy initialization
26
+ smolvla_executor = None # Lazy initialization
27
+
28
+
29
+ def _get_gemini_client() -> GeminiClient:
30
+ """
31
+ Get or create the global GeminiClient instance.
32
+
33
+ Returns:
34
+ GeminiClient instance
35
+ """
36
+ global gemini_client
37
+ if gemini_client is None:
38
+ gemini_client = GeminiClient()
39
+ logger.info("GeminiClient initialized")
40
+ return gemini_client
41
+
42
+
43
+ def _get_stt_service():
44
+ """
45
+ Get or create the global STTService instance.
46
+
47
+ Returns:
48
+ STTService instance
49
+ """
50
+ global stt_service
51
+ if stt_service is None:
52
+ from .stt_service import STTService
53
+ stt_service = STTService()
54
+ logger.info("STTService initialized")
55
+ return stt_service
56
+
57
+
58
+ def _get_tts_service():
59
+ """
60
+ Get or create the global TTSService instance.
61
+
62
+ Returns:
63
+ TTSService instance
64
+ """
65
+ global tts_service
66
+ if tts_service is None:
67
+ from .tts_service import get_tts_service
68
+ tts_service = get_tts_service()
69
+ logger.info("TTSService initialized")
70
+ return tts_service
71
+
72
+
73
+ def _get_intent_router():
74
+ """
75
+ Get or create the global IntentRouter instance.
76
+
77
+ Returns:
78
+ IntentRouter instance
79
+ """
80
+ global intent_router
81
+ if intent_router is None:
82
+ from .intent_router import IntentRouter
83
+ intent_router = IntentRouter()
84
+ logger.info("IntentRouter initialized")
85
+ return intent_router
86
+
87
+
88
+ def _get_smolvla_executor():
89
+ """
90
+ Get or create the global SmolVLAExecutor instance.
91
+
92
+ Returns:
93
+ SmolVLAExecutor instance or None if not configured
94
+ """
95
+ global smolvla_executor
96
+ if smolvla_executor is None:
97
+ import os
98
+
99
+ # Check if we're in simulation mode
100
+ robot_mode = os.getenv("ROBOT_MODE", "physical").lower()
101
+ if robot_mode == "simulation":
102
+ logger.info("SmolVLA disabled in simulation mode")
103
+ smolvla_executor = None
104
+ return None
105
+
106
+ checkpoint_path = os.getenv("SMOLVLA_CHECKPOINT_PATH")
107
+
108
+ if checkpoint_path:
109
+ try:
110
+ from .smolvla_executor import SmolVLAExecutor
111
+ smolvla_executor = SmolVLAExecutor(
112
+ checkpoint_path=checkpoint_path,
113
+ robot_arm=mortis_arm
114
+ )
115
+ logger.info(f"SmolVLAExecutor initialized with checkpoint: {checkpoint_path}")
116
+ except Exception as e:
117
+ logger.warning(f"Failed to initialize SmolVLAExecutor: {e}")
118
+ logger.warning("Manipulation commands will fall back to gestures")
119
+ smolvla_executor = None
120
+ else:
121
+ logger.info("SMOLVLA_CHECKPOINT_PATH not set, manipulation commands will use gestures")
122
+ smolvla_executor = None
123
+
124
+ return smolvla_executor
125
+
126
+
127
+ def ask_mortis(
128
+ user_msg: Optional[str] = None,
129
+ model_name: Optional[str] = None,
130
+ audio_path: Optional[str] = None
131
+ ) -> Tuple[str, str, str]:
132
+ """
133
+ Send user message to Gemini API and get Mortis response with gesture.
134
+
135
+ This function supports both text and voice input through a unified interface.
136
+ It implements the complete voice-to-text-to-Gemini-to-TTS pipeline with
137
+ latency monitoring.
138
+
139
+ Processing flow:
140
+ 1. If audio_path provided, transcribe to text using STT
141
+ 2. Connect to robot arm if not already connected
142
+ 3. Send text message to Gemini API
143
+ 4. Parse structured JSON response
144
+ 5. Return message, mood, and gesture for execution
145
+
146
+ Args:
147
+ user_msg: User's input message text (optional if audio_path provided)
148
+ model_name: Optional Gemini model name (uses default from env if not provided)
149
+ audio_path: Optional path to audio file for voice input
150
+
151
+ Returns:
152
+ Tuple of (message, mood, gesture) where:
153
+ - message: Text response from Mortis
154
+ - mood: Emotional mood (e.g., "ominous", "playful")
155
+ - gesture: Gesture to execute (e.g., "wave", "idle")
156
+
157
+ Raises:
158
+ ValueError: If neither user_msg nor audio_path is provided
159
+
160
+ Note:
161
+ This function maintains backward compatibility with the previous API.
162
+ The gesture is executed immediately via mortis_arm.move_arm() when the arm
163
+ is connected; it is also returned so the caller can display or replay it.
164
+
165
+ Latency monitoring logs are generated for the voice processing pipeline.
166
+ """
167
+ pipeline_start = time.time()
168
+
169
+ # Validate input
170
+ if user_msg is None and audio_path is None:
171
+ raise ValueError("Either user_msg or audio_path must be provided")
172
+
173
+ # Voice input processing
174
+ if audio_path is not None:
175
+ logger.info(f"🎤 Processing voice input from: {audio_path}")
176
+ stt_start = time.time()
177
+
178
+ try:
179
+ # Get STT service
180
+ stt = _get_stt_service()
181
+
182
+ # Transcribe audio to text
183
+ user_msg = stt.transcribe(audio_path)
184
+
185
+ stt_latency = time.time() - stt_start
186
+ logger.info(f"⏱️ STT latency: {stt_latency:.2f}s")
187
+ logger.info(f"📝 Transcribed: '{user_msg[:50]}...'")
188
+
189
+ if not user_msg or not user_msg.strip():
190
+ logger.warning("⚠️ STT returned empty transcription")
191
+ return "I couldn't hear you... speak again.", "nervous", "idle"
192
+
193
+ except Exception as e:
194
+ logger.error(f"❌ Voice input processing failed: {e}")
195
+ return "The spirits couldn't understand... try again.", "ominous", "idle"
196
+
197
+ # Ensure robot is connected
198
+ if not mortis_arm.connected:
199
+ try:
200
+ mortis_arm.connect()
201
+ logger.info("Robot arm connected")
202
+ except Exception as e:
203
+ logger.error(f"Failed to connect to robot arm: {e}")
204
+ # Continue anyway - we can still generate responses
205
+
206
+ # Get Gemini client
207
+ client = _get_gemini_client()
208
+
209
+ # Reconfigure model if specified
210
+ if model_name:
211
+ client.configure_model(model_name=model_name)
212
+ logger.info(f"Using Gemini model: {model_name}")
213
+
214
+ # Send message to Gemini
215
+ logger.info(f"💬 Asking Mortis: {user_msg[:50]}...")
216
+ gemini_start = time.time()
217
+
218
+ response_json = client.send_message(user_msg)
219
+
220
+ gemini_latency = time.time() - gemini_start
221
+ logger.info(f"⏱️ Gemini latency: {gemini_latency:.2f}s")
222
+
223
+ # Parse response using IntentRouter
224
+ try:
225
+ # Get intent router
226
+ router = _get_intent_router()
227
+
228
+ # Parse Gemini response into Intent
229
+ intent = router.parse_gemini_response(response_json)
230
+
231
+ # Extract fields for return
232
+ message = intent.message
233
+ mood = intent.mood
234
+ gesture = intent.gesture if intent.gesture else "idle"
235
+
236
+ # Route based on intent type
237
+ execution_path = router.route_intent(intent)
238
+
239
+ if execution_path == "manipulation":
240
+ # Valid manipulation command - attempt SmolVLA execution
241
+ logger.info(f"🤖 Manipulation command detected: '{intent.command}'")
242
+
243
+ # Try to get SmolVLA executor
244
+ executor = _get_smolvla_executor()
245
+
246
+ if executor is not None:
247
+ try:
248
+ # Execute manipulation task
249
+ logger.info(f"Executing manipulation task: {intent.command}")
250
+ success = executor.execute(intent.command)
251
+
252
+ if success:
253
+ logger.info(f"✅ Manipulation task completed successfully")
254
+ else:
255
+ logger.warning(f"⚠️ Manipulation task did not complete fully")
256
+
257
+ # Return with "manipulation" as gesture to indicate manipulation was executed
258
+ gesture = "manipulation"
259
+
260
+ except Exception as e:
261
+ logger.error(f"❌ SmolVLA execution failed: {e}")
262
+ logger.info("Falling back to gesture execution")
263
+
264
+ # Fallback to gesture execution
265
+ gesture = "idle"
266
+ if mortis_arm.connected:
267
+ mortis_arm.move_arm(gesture)
268
+ else:
269
+ # No SmolVLA executor available, fall back to gesture
270
+ logger.warning("SmolVLA executor not available, falling back to gesture")
271
+ gesture = "idle"
272
+ if mortis_arm.connected:
273
+ mortis_arm.move_arm(gesture)
274
+
275
+ elif execution_path == "gesture":
276
+ # Conversational response with gesture
277
+ logger.info(f"💬 Conversation with gesture: {gesture}")
278
+
279
+ # Execute gesture immediately
280
+ if mortis_arm.connected:
281
+ try:
282
+ mortis_arm.move_arm(gesture)
283
+ except Exception as e:
284
+ logger.error(f"Failed to execute gesture '{gesture}': {e}")
285
+
286
+ elif execution_path == "invalid":
287
+ # Invalid intent - fall back to gesture
288
+ logger.warning(f"⚠️ Invalid intent: {intent.validation_error}")
289
+ logger.info("Falling back to conversational gesture")
290
+
291
+ # Use gesture from intent or default to idle
292
+ gesture = intent.gesture if intent.gesture else "idle"
293
+
294
+ # Execute gesture
295
+ if mortis_arm.connected:
296
+ try:
297
+ mortis_arm.move_arm(gesture)
298
+ except Exception as e:
299
+ logger.error(f"Failed to execute fallback gesture '{gesture}': {e}")
300
+
301
+ # Calculate total pipeline latency
302
+ total_latency = time.time() - pipeline_start
303
+ logger.info(f"⏱️ Total pipeline latency: {total_latency:.2f}s")
304
+ logger.info(f"👻 Mortis responds (path: {execution_path}, mood: {mood}, gesture: {gesture})")
305
+
306
+ return message, mood, gesture
307
+
308
+ except (ValueError, KeyError) as e:
309
+ # If parsing fails, return safe defaults
310
+ logger.error(f"Failed to parse Gemini response: {e}")
311
+ logger.error(f"Response JSON: {response_json}")
312
+
313
+ # Return fallback response
314
+ return "The spirits are confused... try again.", "ominous", "idle"
315
+
316
+
317
+ def ask_mortis_with_voice(
318
+ user_msg: Optional[str] = None,
319
+ model_name: Optional[str] = None,
320
+ audio_path: Optional[str] = None,
321
+ generate_audio: bool = True
322
+ ) -> Tuple[str, str, str, Optional[str]]:
323
+ """
324
+ Complete voice-to-text-to-Gemini-to-TTS pipeline with audio output.
325
+
326
+ This is a convenience function that wraps ask_mortis() and adds TTS
327
+ generation for the response. It provides the full multi-modal experience.
328
+
329
+ Args:
330
+ user_msg: User's input message text (optional if audio_path provided)
331
+ model_name: Optional Gemini model name
332
+ audio_path: Optional path to audio file for voice input
333
+ generate_audio: Whether to generate audio output (default: True)
334
+
335
+ Returns:
336
+ Tuple of (message, mood, gesture, audio_path) where:
337
+ - message: Text response from Mortis
338
+ - mood: Emotional mood
339
+ - gesture: Gesture to execute
340
+ - audio_path: Path to generated audio file (None if generation fails)
341
+
342
+ Note:
343
+ This function logs latency for the complete voice processing pipeline
344
+ including STT, Gemini inference, and TTS generation.
345
+ """
346
+ pipeline_start = time.time()
347
+
348
+ # Get text response from Gemini (handles STT if audio_path provided)
349
+ message, mood, gesture = ask_mortis(
350
+ user_msg=user_msg,
351
+ model_name=model_name,
352
+ audio_path=audio_path
353
+ )
354
+
355
+ # Generate audio response if requested
356
+ response_audio_path = None
357
+ if generate_audio:
358
+ tts_start = time.time()
359
+
360
+ try:
361
+ # Get TTS service
362
+ tts = _get_tts_service()
363
+
364
+ # Generate audio
365
+ response_audio_path = tts.synthesize(message)
366
+
367
+ tts_latency = time.time() - tts_start
368
+ logger.info(f"⏱️ TTS latency: {tts_latency:.2f}s")
369
+
370
+ if response_audio_path:
371
+ logger.info(f"🔊 Audio generated: {response_audio_path}")
372
+ else:
373
+ logger.warning("⚠️ TTS returned None")
374
+
375
+ except Exception as e:
376
+ logger.error(f"❌ TTS generation failed: {e}")
377
+ # Continue without audio - text response is still valid
378
+
379
+ # Log total pipeline latency including TTS
380
+ total_latency = time.time() - pipeline_start
381
+ logger.info(f"⏱️ Complete voice pipeline latency: {total_latency:.2f}s")
382
+
383
+ return message, mood, gesture, response_audio_path
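A minimal round-trip through the wrapper above (a sketch; the recording path is a placeholder and valid API credentials are assumed):

    message, mood, gesture, audio_out = ask_mortis_with_voice(
        audio_path="command.wav",  # placeholder recording
        generate_audio=True,
    )
    print(message, mood, gesture)
    if audio_out:
        print(f"Play response from: {audio_out}")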
384
+
385
+
386
+
387
+ if __name__ == "__main__":
388
+ # Configure logging for testing
389
+ logging.basicConfig(level=logging.INFO)
390
+
391
+ # Test conversational interactions
392
+ print("=== Test 1: Greeting ===")
393
+ message, mood, gesture = ask_mortis("Mortis, someone is entering the lab… act!")
394
+ print(f"Message: {message}")
395
+ print(f"Mood: {mood}")
396
+ print(f"Gesture: {gesture}")
397
+ print()
398
+
399
+ print("=== Test 2: Introduction ===")
400
+ message, mood, gesture = ask_mortis("Introduce yourself with a sinister bow.")
401
+ print(f"Message: {message}")
402
+ print(f"Mood: {mood}")
403
+ print(f"Gesture: {gesture}")
404
+ print()
405
+
406
+ print("=== Test 3: Action sequence ===")
407
+ message, mood, gesture = ask_mortis("Grab the cursed vial and then release it.")
408
+ print(f"Message: {message}")
409
+ print(f"Mood: {mood}")
410
+ print(f"Gesture: {gesture}")
411
+ print()
412
+
413
+ print("=== Test 4: Manipulation command ===")
414
+ message, mood, gesture = ask_mortis("Can you move the skull to the green cup?")
415
+ print(f"Message: {message}")
416
+ print(f"Mood: {mood}")
417
+ print(f"Gesture: {gesture}")
418
+ print()
src/mortis/tts_service.py ADDED
@@ -0,0 +1,225 @@
1
+ """
2
+ Text-to-Speech service for Mortis voice output.
3
+
4
+ Provides TTS capabilities using Google Cloud Text-to-Speech API with
5
+ fallback to local gTTS for offline scenarios.
6
+ """
7
+
8
+ import os
9
+ import time
10
+ import logging
11
+ from pathlib import Path
12
+ from typing import Optional
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ class TTSService:
18
+ """
19
+ Text-to-Speech service for converting Mortis responses to audio.
20
+
21
+ Uses Google Cloud TTS as primary service with gTTS as fallback.
22
+ Configured for a deep, ominous voice suitable for Mortis character.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ output_dir: str = "outputs",
28
+ use_google_tts: bool = True,
29
+ voice_name: str = "en-US-Neural2-D",
30
+ speaking_rate: float = 0.9,
31
+ pitch: float = -2.0
32
+ ):
33
+ """
34
+ Initialize TTS service.
35
+
36
+ Args:
37
+ output_dir: Directory for generated audio files
38
+ use_google_tts: Whether to use Google Cloud TTS (requires credentials)
39
+ voice_name: Google TTS voice name (Neural2-D is deep male voice)
40
+ speaking_rate: Speech speed (0.9 = slightly slower for ominous effect)
41
+ pitch: Voice pitch (-2.0 = lower for spooky voice)
42
+ """
43
+ self.output_dir = Path(output_dir)
44
+ self.output_dir.mkdir(parents=True, exist_ok=True)
45
+
46
+ self.use_google_tts = use_google_tts
47
+ self.voice_name = voice_name
48
+ self.speaking_rate = speaking_rate
49
+ self.pitch = pitch
50
+
51
+ # Try to initialize Google TTS client
52
+ self.google_client = None
53
+ self.texttospeech = None
54
+ if self.use_google_tts:
55
+ try:
56
+ from google.cloud import texttospeech
57
+ self.google_client = texttospeech.TextToSpeechClient()
58
+ self.texttospeech = texttospeech
59
+ logger.info("Google Cloud TTS initialized successfully")
60
+ except ImportError as e:
61
+ logger.warning(f"Google Cloud TTS not available: {e}. Will use gTTS fallback.")
62
+ self.use_google_tts = False
63
+ except Exception as e:
64
+ logger.warning(f"Failed to initialize Google TTS: {e}. Will use gTTS fallback.")
65
+ self.use_google_tts = False
66
+
67
+ logger.info(f"TTS Service initialized (Google TTS: {self.use_google_tts})")
68
+
69
+ def synthesize(self, text: str, filename: Optional[str] = None) -> Optional[str]:
70
+ """
71
+ Convert text to speech audio file.
72
+
73
+ Args:
74
+ text: Text to convert to speech
75
+ filename: Optional custom filename (without extension)
76
+
77
+ Returns:
78
+ Path to generated audio file, or None if synthesis fails
79
+ """
80
+ if not text or not text.strip():
81
+ logger.warning("Empty text provided to TTS service")
82
+ return None
83
+
84
+ # Generate filename if not provided
85
+ if filename is None:
86
+ timestamp = int(time.time() * 1000)
87
+ filename = f"mortis_response_{timestamp}"
88
+
89
+ # Try Google TTS first
90
+ if self.use_google_tts and self.google_client:
91
+ try:
92
+ audio_path = self._synthesize_google_tts(text, filename)
93
+ logger.info(f"Generated audio with Google TTS: {audio_path}")
94
+ return audio_path
95
+ except Exception as e:
96
+ logger.error(f"Google TTS failed: {e}. Falling back to gTTS.")
97
+
98
+ # Fallback to gTTS
99
+ try:
100
+ audio_path = self._synthesize_gtts(text, filename)
101
+ logger.info(f"Generated audio with gTTS: {audio_path}")
102
+ return audio_path
103
+ except Exception as e:
104
+ logger.error(f"gTTS also failed: {e}. No audio generated.")
105
+ return None
106
+
107
+ def _synthesize_google_tts(self, text: str, filename: str) -> str:
108
+ """
109
+ Synthesize speech using Google Cloud TTS.
110
+
111
+ Args:
112
+ text: Text to synthesize
113
+ filename: Base filename (without extension)
114
+
115
+ Returns:
116
+ Path to generated MP3 file
117
+ """
118
+ # Prepare synthesis input
119
+ synthesis_input = self.texttospeech.SynthesisInput(text=text)
120
+
121
+ # Configure voice parameters for Mortis character
122
+ voice = self.texttospeech.VoiceSelectionParams(
123
+ language_code="en-US",
124
+ name=self.voice_name,
125
+ ssml_gender=self.texttospeech.SsmlVoiceGender.MALE
126
+ )
127
+
128
+ # Configure audio output
129
+ audio_config = self.texttospeech.AudioConfig(
130
+ audio_encoding=self.texttospeech.AudioEncoding.MP3,
131
+ speaking_rate=self.speaking_rate,
132
+ pitch=self.pitch
133
+ )
134
+
135
+ # Perform synthesis
136
+ response = self.google_client.synthesize_speech(
137
+ input=synthesis_input,
138
+ voice=voice,
139
+ audio_config=audio_config
140
+ )
141
+
142
+ # Save audio file
143
+ output_path = self.output_dir / f"{filename}.mp3"
144
+ with open(output_path, "wb") as out:
145
+ out.write(response.audio_content)
146
+
147
+ return str(output_path)
148
+
149
+ def _synthesize_gtts(self, text: str, filename: str) -> str:
150
+ """
151
+ Synthesize speech using gTTS (local fallback).
152
+
153
+ Args:
154
+ text: Text to synthesize
155
+ filename: Base filename (without extension)
156
+
157
+ Returns:
158
+ Path to generated MP3 file
159
+ """
160
+ from gtts import gTTS
161
+
162
+ # Create TTS object with slower speech for ominous effect
163
+ tts = gTTS(text=text, lang='en', slow=True)
164
+
165
+ # Save audio file
166
+ output_path = self.output_dir / f"{filename}.mp3"
167
+ tts.save(str(output_path))
168
+
169
+ return str(output_path)
170
+
171
+ def cleanup_old_files(self, max_age_seconds: int = 3600):
172
+ """
173
+ Remove old audio files to prevent disk space issues.
174
+
175
+ Args:
176
+ max_age_seconds: Maximum age of files to keep (default: 1 hour)
177
+ """
178
+ current_time = time.time()
179
+ removed_count = 0
180
+
181
+ for audio_file in self.output_dir.glob("mortis_response_*.mp3"):
182
+ try:
183
+ file_age = current_time - audio_file.stat().st_mtime
184
+ if file_age > max_age_seconds:
185
+ audio_file.unlink()
186
+ removed_count += 1
187
+ except Exception as e:
188
+ logger.warning(f"Failed to remove old file {audio_file}: {e}")
189
+
190
+ if removed_count > 0:
191
+ logger.info(f"Cleaned up {removed_count} old audio files")
192
+
193
+
194
+ # Global TTS service instance
195
+ _tts_service: Optional[TTSService] = None
196
+
197
+
198
+ def get_tts_service() -> TTSService:
199
+ """
200
+ Get or create global TTS service instance.
201
+
202
+ Returns:
203
+ Singleton TTSService instance
204
+ """
205
+ global _tts_service
206
+ if _tts_service is None:
207
+ # Check if Google Cloud credentials are available
208
+ use_google = bool(os.getenv("GOOGLE_APPLICATION_CREDENTIALS"))
209
+ _tts_service = TTSService(use_google_tts=use_google)
210
+ return _tts_service
211
+
212
+
213
+ def synthesize_speech(text: str, filename: Optional[str] = None) -> Optional[str]:
214
+ """
215
+ Convenience function to synthesize speech using global TTS service.
216
+
217
+ Args:
218
+ text: Text to convert to speech
219
+ filename: Optional custom filename
220
+
221
+ Returns:
222
+ Path to generated audio file, or None if synthesis fails
223
+ """
224
+ service = get_tts_service()
225
+ return service.synthesize(text, filename)
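End-to-end, the module-level helper is the simplest entry point (a sketch):

    audio_path = synthesize_speech("The spirits stir tonight...")
    if audio_path:
        print(f"Audio written to: {audio_path}")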