ArseniyPerchik committed
Commit 8e6ca7f · 1 Parent(s): 4b49a7c
Files changed (3):
  1. app.py +33 -36
  2. draft_1.ipynb +0 -258
  3. draft_tts.py +21 -0
app.py CHANGED
@@ -170,47 +170,44 @@ def get_target_from_request(request_text):
     return 'No goal found.'
 
 
-def create_demo():
 # main blocks
-    with gr.Blocks(css=custom_css) as my_demo:
-        gr.Markdown("# Agent Control with Language")
-        gr.Markdown('Tell the agent where to go and what to do')
-
-        with gr.Row():
-            with gr.Column():
-                request_audio = gr.Microphone(editable=False)
-                # send_btn = gr.Button(value='Send Request')
-                request_text = gr.Textbox(label="Request:", lines=2, interactive=False)
-                request_target = gr.Textbox(label='Target:', lines=2)
-                status = gr.Textbox(label='Status:', lines=2, elem_id="mytextbox")
-            with gr.Column():
-                output_env = gr.Video(label="Env:", autoplay=True)
-                with gr.Accordion("TODO List", open=False):
-                    gr.Markdown("""
-                    ## PLAN
-                    - [x] to use audio as an input for requests
-                    - [x] to learn a policy for navigation from location to location
-                    - [x] to build an interface that will show the status of the request
-                    - [ ] to incorporate a longer chain of goals; for example, go there and pick the package, then come back
-                    - [ ] to introduce additional learnt capabilities
-                    - [ ] to build more complex environments where the movement is not so straightforward
-                    """)
-
-        # EVENTS:
-        # gr.on(triggers=["load"], fn=load_image_on_start, outputs=output_env_image)
-        # my_demo.load(fn=load_image_on_start, outputs=output_env_image)
-        my_demo.load(fn=create_standing_animation, outputs=output_env)
-        # request_audio.stream(fn=get_text_request, inputs=request_audio, outputs=request_text)
-        request_audio.stop_recording(fn=get_text_request, inputs=request_audio, outputs=request_text)
-        request_text.change(fn=get_target_from_request, inputs=request_text, outputs=request_target)
-        request_target.change(fn=move_agent, inputs=request_target, outputs=[output_env, status])
-        request_audio.stop_recording(lambda: None, outputs=request_audio)
-    return my_demo
+with gr.Blocks(css=custom_css) as demo:
+    gr.Markdown("# Agent Control with Language")
+    gr.Markdown('Tell the agent where to go and what to do')
+
+    with gr.Row():
+        with gr.Column():
+            request_audio = gr.Microphone(editable=False)
+            # send_btn = gr.Button(value='Send Request')
+            request_text = gr.Textbox(label="Request:", lines=2, interactive=False)
+            request_target = gr.Textbox(label='Target:', lines=2)
+            status = gr.Textbox(label='Status:', lines=2, elem_id="mytextbox")
+        with gr.Column():
+            output_env = gr.Video(label="Env:", autoplay=True)
+            with gr.Accordion("TODO List", open=False):
+                gr.Markdown("""
+                ## PLAN
+                - [x] to use audio as an input for requests
+                - [x] to learn a policy for navigation from location to location
+                - [x] to build an interface that will show the status of the request
+                - [ ] to incorporate a longer chain of goals; for example, go there and pick the package, then come back
+                - [ ] to introduce additional learnt capabilities
+                - [ ] to build more complex environments where the movement is not so straightforward
+                """)
+
+    # EVENTS:
+    # gr.on(triggers=["load"], fn=load_image_on_start, outputs=output_env_image)
+    # my_demo.load(fn=load_image_on_start, outputs=output_env_image)
+    demo.load(fn=create_standing_animation, outputs=output_env)
+    # request_audio.stream(fn=get_text_request, inputs=request_audio, outputs=request_text)
+    request_audio.stop_recording(fn=get_text_request, inputs=request_audio, outputs=request_text)
+    request_text.change(fn=get_target_from_request, inputs=request_text, outputs=request_target)
+    request_target.change(fn=move_agent, inputs=request_target, outputs=[output_env, status])
+    request_audio.stop_recording(lambda: None, outputs=request_audio)
 
 
 # ---------------------------- #
 # main
 # ---------------------------- #
-demo = create_demo()
 demo.launch()
 
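The net effect of this hunk: the Blocks app is no longer built inside a `create_demo()` factory but bound to a module-level variable named `demo`. The event wiring itself is unchanged — `stop_recording` transcribes the audio, the `request_text` change extracts a target, and the `request_target` change drives `move_agent`. The module-level name matters because Gradio's reload mode (`gradio app.py`, the command noted at the top of the deleted notebook below) looks for a Blocks instance called `demo` by default. A minimal sketch of the pattern, with the CSS and components elided:

```python
import gradio as gr

# A module-level Blocks bound to the name `demo`: gradio's reload mode
# (`gradio app.py`) discovers this object by name, so no factory function
# or `demo = create_demo()` call is needed.
with gr.Blocks() as demo:
    gr.Markdown("# Agent Control with Language")

demo.launch()
```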
draft_1.ipynb DELETED
@@ -1,258 +0,0 @@
The deleted notebook, rendered cell by cell (stdout/stderr shown as comments; notebook metadata omitted):

## FIRST CHECK

# gradio app.py --watch-dirs app.py

import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import tempfile
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import torchaudio
import torchaudio.transforms as T
from matplotlib.patches import Circle
from stable_baselines3 import SAC
from warehouse_env import WarehouseEnv
from types import SimpleNamespace

asr_pipe_default = pipeline("automatic-speech-recognition")
# stderr: No model was supplied, defaulted to facebook/wav2vec2-base-960h
#         and revision 22aad52 (https://huggingface.co/facebook/wav2vec2-base-960h).
#         Using a pipeline without specifying a model name and revision in
#         production is not recommended.
#         Some weights of Wav2Vec2ForCTC were not initialized from the model
#         checkpoint at facebook/wav2vec2-base-960h and are newly initialized:
#         ['wav2vec2.masked_spec_embed']
#         You should probably TRAIN this model on a down-stream task to be
#         able to use it for predictions and inference.
#         Device set to use mps:0

waveform, sample_rate = torchaudio.load("sample.wav")
target_sr = 16000
resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sr, dtype=waveform.dtype)
waveform = resampler(waveform)
waveform_np = waveform.squeeze().numpy()
# sample = dataset[2]["audio"]

# result = pipe(sample)
result = asr_pipe_default(waveform_np)
print(result["text"])
# stdout: THIS IS A SIMPLE TEXT

## TO SAVE THE MODEL

save_dir = './models_for_proj/wav2vec2-base-960h'

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Load pretrained model and processor
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

# Save locally
model.save_pretrained(save_dir)
processor.save_pretrained(save_dir)

## TO REUSE IT

import torchaudio
import torchaudio.transforms as T

waveform, sample_rate = torchaudio.load("sample.wav")
target_sr = 16000
resampler = T.Resample(orig_freq=sample_rate, new_freq=target_sr, dtype=waveform.dtype)
waveform = resampler(waveform)
waveform_np = waveform.squeeze().numpy()

import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

save_dir = './models_for_proj/wav2vec2-base-960h'

# load
model = Wav2Vec2ForCTC.from_pretrained(save_dir)
processor = Wav2Vec2Processor.from_pretrained(save_dir)

# Preprocess
inputs = processor(waveform_np, sampling_rate=16000, return_tensors="pt", padding=True)

# Inference
with torch.no_grad():
    logits = model(**inputs).logits

# Decode
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.decode(predicted_ids[0])

print("Transcription:", transcription)
# stdout: Transcription: THIS IS A SIMPLE TEXT
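The resampling cells above exist because wav2vec2-base-960h expects 16 kHz input. Note, though, that for a file-path input the `transformers` ASR pipeline decodes and resamples the audio itself (via ffmpeg), so — as a sketch, assuming `sample.wav` and a local ffmpeg install are available — the first check collapses to:

```python
from transformers import pipeline

# Pin the model instead of relying on the pipeline default (the default
# is what produced the "No model was supplied" warning above).
asr = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")

# A file path is decoded and resampled to the model's 16 kHz internally.
print(asr("sample.wav")["text"])  # expected: THIS IS A SIMPLE TEXT
```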
draft_tts.py ADDED
@@ -0,0 +1,21 @@
+from transformers import AutoProcessor, BarkModel
+import torch
+import scipy.io.wavfile
+
+# Load model and processor
+processor = AutoProcessor.from_pretrained("suno/bark")
+model = BarkModel.from_pretrained("suno/bark")
+
+# Input text
+text = "Hello! This is Bark speaking from Hugging Face."
+
+# Prepare inputs
+inputs = processor(text, return_tensors="pt")
+
+# Generate audio
+with torch.no_grad():
+    audio = model.generate(**inputs)
+
+# Save the waveform
+audio = audio.cpu().numpy().squeeze()
+scipy.io.wavfile.write("bark_output.wav", rate=22050, data=audio)
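
A caveat on the hard-coded rate: Bark generates audio at 24 kHz, so writing the file at 22050 Hz will play it slightly slowed and pitched down. A minimal sketch of the fix, reusing the `model` and `audio` variables from draft_tts.py above:

```python
# Read the native rate from the model instead of hard-coding 22050;
# BarkModel exposes it on generation_config (24000 for suno/bark).
sample_rate = model.generation_config.sample_rate
scipy.io.wavfile.write("bark_output.wav", rate=sample_rate, data=audio)
```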