# gradio app.py --watch-dirs app.py

import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import tempfile
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torchaudio
import torchaudio.transforms as T
from matplotlib.patches import Circle
from stable_baselines3 import SAC
from warehouse_env import WarehouseEnv
from types import SimpleNamespace


# ---------------------------- #
# global variables
# ---------------------------- #
# models
# a model for the automatic-speech-recognition task
# asr_pipe_default = pipeline("automatic-speech-recognition")
save_dir = './models_for_proj/wav2vec2-base-960h'
model = Wav2Vec2ForCTC.from_pretrained(save_dir)
processor = Wav2Vec2Processor.from_pretrained(save_dir)
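# NOTE: the checkpoint is expected to have been downloaded into save_dir beforehand; from_pretrained loads it from this local path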


# env variables
rl_model_name = 'agent_policies/sac_warehouse_r_10_working_v1.zip'
# agent_pos = {'x': 50.0, 'y': 50.0}
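# mutable module-level state: the agent's current position, updated after every completed move request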
agent_pos = SimpleNamespace(**{'x': 50.0, 'y': 50.0})
goal_dict = {
    '1': (20, 20),
    '2': (80, 20),
    '3': (80, 80),
    '4': (20, 80),
}
targets_x, targets_y = [], []
for k, v in goal_dict.items():
    targets_x.append(v[0])
    targets_y.append(v[1])
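# radius of the coverage circle drawn around each goal in the animation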
r_coverage = 10


custom_css = """
#mytextbox textarea {
    color: blue;
    background-color: #f0f0f0;
    font-weight: bold;
}
"""

# ---------------------------- #
# functions
# ---------------------------- #
def create_standing_animation():
    path = [(agent_pos.x, agent_pos.y)]
    return create_animation(path)


def create_animation(path):
    # path = [(i,i) for i in range(90)]
    # targets_x = [20, 80, 80, 20]
    # targets_y = [20, 20, 80, 80]
    # RADIUS_COVERAGE = 10
    fig, ax = plt.subplots(figsize=(7, 7))

    # agent
    ln1, = plt.plot([path[0][0]], [path[0][1]], marker='o', color='b', alpha=0.5, linewidth=5, markersize=15)

    # targets (uses the module-level targets_x / targets_y computed from goal_dict)
    ln2, = plt.plot(targets_x, targets_y, marker='X', color='orange', alpha=0.5, linestyle='none', markersize=15)
    for t_name, (t_x, t_y) in goal_dict.items():
        circle = Circle((t_x, t_y), r_coverage, color='orange', fill=True, alpha=0.3)
        plt.text(t_x, t_y, f"{t_name}", fontsize=20, color='k')
        ax.add_patch(circle)
    # plt.tight_layout()

    def init():
        ax.set_xlim([0, 100])
        ax.set_ylim([0, 100])
        ax.set_title('Warehouse Env', fontweight="bold", size=10)
        return ln1,

    def update(frame):
        # for each frame, update the data stored on each artist.
        x = [path[frame][0]]
        y = [path[frame][1]]

        ln1.set_data(x, y)
        return ln1,

    ani = animation.FuncAnimation(fig, update, frames=len(path),
                                  init_func=init, blit=True, repeat=False)
    # plt.show()

    # Save to MP4
    temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    ani.save(temp_video.name, writer='ffmpeg', fps=30)
    plt.close(fig)
    return temp_video.name


def move_agent(target_input: str):
    # target_input arrives as text from the Target textbox, so it is compared against the string keys of goal_dict
    if target_input not in goal_dict:
        return create_standing_animation(), 'Say it again... To which goal do you want me to go?'
    # get goal locations:
    goal_x, goal_y = goal_dict[target_input]
    # build the path
    env: WarehouseEnv = WarehouseEnv(render_mode='')
    rl_model = SAC.load(rl_model_name)  # avoid shadowing the global ASR `model`
    obs, info = env.reset(agent_x=agent_pos.x, agent_y=agent_pos.y, goal_x=goal_x, goal_y=goal_y)
    path = []
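    # roll out the learned policy until the goal is reached; on truncation, reset the episode and discard the partial path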
    while True:
        action, _ = rl_model.predict(obs)
        obs, rewards, done, trunc, info = env.step(action)
        path.append((env.agent_x, env.agent_y))
        if done:
            break
        if trunc:
            obs, info = env.reset(agent_x=agent_pos.x, agent_y=agent_pos.y, goal_x=goal_x, goal_y=goal_y)
            path = []

    agent_pos.x = path[-1][0]
    agent_pos.y = path[-1][1]
    # create animation
    video_output = create_animation(path)

    return video_output, f'Got it! I went to goal number {target_input}.'


def load_image_on_start():
    return np.random.rand(700, 700)

def get_text_request(audio_input):
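    # audio_input is the (sample_rate, waveform) tuple produced by the Gradio Microphone component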
    # --------------------------------------------------------------------------- #
    audio_input_sr, audio_input_np = audio_input
    audio_input_t = torch.tensor(audio_input_np, dtype=torch.float32)
    target_sr = 16000
    resampler = T.Resample(audio_input_sr, target_sr, dtype=audio_input_t.dtype)
    resampled_audio_input_t: torch.Tensor = resampler(audio_input_t)
    resampled_audio_input_np = resampled_audio_input_t.numpy()
    # --------------------------------------------------------------------------- #
    # result = asr_pipe_default(resampled_audio_input_np)
    inputs = processor(resampled_audio_input_np, sampling_rate=16000, return_tensors="pt", padding=True)
    # Inference
    with torch.no_grad():
        logits = model(**inputs).logits
    # Decode
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    # print("Transcription:", transcription)
    return transcription

def get_target_from_request(request_text):
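    # naive keyword spotting on the Wav2Vec2 transcription (upper-case output); 'FOR' is included as a likely mis-transcription of 'FOUR'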
    if any(item in request_text for item in ['ONE', 'FIRST']):
        return 1
    if any(item in request_text for item in ['TWO', 'SECOND']):
        return 2
    if any(item in request_text for item in ['THREE', 'THIRD']):
        return 3
    if any(item in request_text for item in ['FOUR', 'FOURTH', 'FOR']):
        return 4
    return 'No goal found.'
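# Rough usage sketch (assuming the usual upper-case Wav2Vec2 output):
#   get_target_from_request('GO TO TARGET NUMBER THREE')  -> 3
#   get_target_from_request('HELLO THERE')                -> 'No goal found.'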


# ---------------------------- #
# main blocks
# ---------------------------- #
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# Agent Control with Language")
    gr.Markdown('Tell the agent where to go')
    # gr.Markdown('Tell the agent where to go and what to do')

    with gr.Row():
        with gr.Column(scale=12):
            status = gr.Textbox(label='Status:', lines=2, elem_id="mytextbox")
            request_audio = gr.Microphone(editable=False, scale=2)
            # send_btn = gr.Button(value='Send Request')
            request_text = gr.Textbox(label="Request:", lines=2, interactive=False, visible=False)
            request_target = gr.Textbox(label='Target:', lines=2, visible=False)
        with gr.Column(scale=8):
            output_env = gr.Video(label="Env:", autoplay=True)
    with gr.Accordion("TODO List", open=False):
        gr.Markdown("""
        ## PLAN
        - [x] to use audio as an input for requests
        - [x] to learn a policy for navigation from location to location
        - [x] to build an interface that will show the status of the request
        - [ ] to incorporate a longer chain of goals; for example, go there, pick up the package, then come back
        - [ ] to introduce additional learnt capabilities
        - [ ] to build more complex environments where the movement is not so straightforward
        """)

    # EVENTS:
    # gr.on(triggers=["load"], fn=load_image_on_start, outputs=output_env_image)
    # my_demo.load(fn=load_image_on_start, outputs=output_env_image)
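    # chain: stop_recording -> transcription -> target extraction -> agent movement; the final handler clears the microphone for the next request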
    demo.load(fn=create_standing_animation, outputs=output_env)
    demo.load(fn=lambda: 'To which target do you want me to go?', outputs=status)
    # request_audio.stream(fn=get_text_request, inputs=request_audio, outputs=request_text)
    request_audio.stop_recording(fn=get_text_request, inputs=request_audio, outputs=request_text)
    request_text.change(fn=get_target_from_request, inputs=request_text, outputs=request_target)
    request_target.change(fn=move_agent, inputs=request_target, outputs=[output_env, status])
    request_audio.stop_recording(lambda: None, outputs=request_audio)

# ---------------------------- #
# main
# ---------------------------- #
demo.launch()



# device = "cuda:0" if torch.cuda.is_available() else "cpu"
# torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# model_id = "./models_for_proj/librispeech_asr_dummy"
# model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
# model.to(device)
# processor = AutoProcessor.from_pretrained(model_id)
# asr_pipe = pipeline(
#     "automatic-speech-recognition",
#     model=model,
#     tokenizer=processor.tokenizer,
#     feature_extractor=processor.feature_extractor,
#     max_new_tokens=128,
#     torch_dtype=torch_dtype,
#     device=device,
# )