Spaces:
Build error
Build error
File size: 8,426 Bytes
25a1345 c0f0687 fbd53e3 25a1345 d250443 25a1345 d250443 25a1345 c0f0687 560834a 25a1345 d250443 25a1345 d250443 25a1345 fbd53e3 25a1345 9bf5ef6 25a1345 9bf5ef6 25a1345 9bf5ef6 25a1345 c0f0687 fbd53e3 25a1345 fbd53e3 25a1345 fbd53e3 25a1345 fbd53e3 25a1345 fbd53e3 25a1345 c1e3f7e 25a1345 d250443 25a1345 560834a 25a1345 fbd53e3 25a1345 d250443 25a1345 d250443 25a1345 d250443 25a1345 d250443 25a1345 d250443 25a1345 312b8c7 25a1345 c1e3f7e 25a1345 8e6ca7f 311bb5a 8e6ca7f 312b8c7 8e6ca7f 312b8c7 8e6ca7f 311bb5a 8e6ca7f 25a1345 c0f0687 d250443 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 |
# gradio app.py --watch-dirs app.py
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import tempfile
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torchaudio
import torchaudio.transforms as T
from matplotlib.patches import Circle
from stable_baselines3 import SAC
from warehouse_env import WarehouseEnv
from types import SimpleNamespace
# ---------------------------- #
# global variables
# ---------------------------- #
# models
# a model for the automatic-speech-recognition task
# asr_pipe_default = pipeline("automatic-speech-recognition")
# Local snapshot of the wav2vec2 CTC model + its processor; loaded once at
# import time (module start-up blocks until both are read from disk).
save_dir = './models_for_proj/wav2vec2-base-960h'
model = Wav2Vec2ForCTC.from_pretrained(save_dir)
processor = Wav2Vec2Processor.from_pretrained(save_dir)
# env variables
# Path to the trained SAC policy; loaded lazily inside move_agent().
rl_model_name = 'agent_policies/sac_warehouse_r_10_working_v1.zip'
# agent_pos = {'x': 50.0, 'y': 50.0}
# Mutable agent position shared across requests — move_agent() updates it
# after each trip so the next trip starts where the last one ended.
agent_pos = SimpleNamespace(**{'x': 50.0, 'y': 50.0})
# Goal id (as spoken/parsed, string '1'..'4') -> (x, y) location on the
# 100x100 warehouse grid.
goal_dict = {
    '1': (20, 20),
    '2': (80, 20),
    '3': (80, 80),
    '4': (20, 80),
}
# Flattened goal coordinates (parallel lists), used for plotting.
targets_x, targets_y = [], []
for k, v in goal_dict.items():
    targets_x.append(v[0])
    targets_y.append(v[1])
# Radius of the translucent coverage circle drawn around each goal.
r_coverage = 10
# Styling for the status textbox (elem_id="mytextbox" in the UI below).
custom_css = """
#mytextbox textarea {
    color: blue;
    background-color: #f0f0f0;
    font-weight: bold;
}
"""
# ---------------------------- #
# functions
# ---------------------------- #
def create_standing_animation():
    """Render a single-frame video of the agent idling at its current position."""
    return create_animation([(agent_pos.x, agent_pos.y)])
def create_animation(path):
    """Render the agent's trajectory through the warehouse as an MP4.

    Args:
        path: non-empty sequence of (x, y) agent positions, one per frame.

    Returns:
        Filesystem path of the rendered .mp4 (caller is responsible for
        cleanup — the temp file is created with delete=False).
    """
    fig, ax = plt.subplots(figsize=(7, 7))
    # agent marker, starting at the first point of the trajectory
    ln1, = ax.plot([path[0][0]], [path[0][1]], marker='o', color='b',
                   alpha=0.5, linewidth=5, markersize=15)
    # goal markers: one 'X' per goal plus a translucent coverage circle and label
    # (use the module-level goal_dict directly instead of rebuilding the
    # targets_x/targets_y lists, which shadowed the module-level globals)
    goal_xs = [x for x, _ in goal_dict.values()]
    goal_ys = [y for _, y in goal_dict.values()]
    ax.plot(goal_xs, goal_ys, marker='X', color='orange', alpha=0.5,
            linestyle='none', markersize=15)
    for t_name, (t_x, t_y) in goal_dict.items():
        circle = Circle((t_x, t_y), r_coverage, color='orange', fill=True, alpha=0.3)
        ax.text(t_x, t_y, f"{t_name}", fontsize=20, color='k')
        ax.add_patch(circle)

    def init():
        # fixed 100x100 world extent regardless of the trajectory
        ax.set_xlim([0, 100])
        ax.set_ylim([0, 100])
        ax.set_title('Warehouse Env', fontweight="bold", size=10)
        return ln1,

    def update(frame):
        # move the agent marker to the trajectory point for this frame
        ln1.set_data([path[frame][0]], [path[frame][1]])
        return ln1,

    ani = animation.FuncAnimation(fig, update, frames=len(path),
                                  init_func=init, blit=True, repeat=False)
    # Save to MP4. Close the handle before ffmpeg writes to the path —
    # keeping it open breaks on platforms that lock open files (Windows).
    temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    temp_video.close()
    ani.save(temp_video.name, writer='ffmpeg', fps=30)
    plt.close(fig)
    return temp_video.name
def move_agent(target_input):
    """Drive the RL agent to the requested goal and render the trip.

    Args:
        target_input: goal identifier. Arrives as a string from the hidden
            request_target gr.Textbox ('1'..'4'); plain ints are also
            accepted (coerced with str() before the goal_dict lookup — the
            previous `int` annotation was misleading, goal_dict keys are
            strings).

    Returns:
        (video_path, status_message) tuple for the [output_env, status]
        outputs of the change event.
    """
    goal_key = str(target_input)
    if goal_key not in goal_dict:
        return create_standing_animation(), 'Say it again.. To what goal you want me to go?..'
    # get goal locations:
    goal_x, goal_y = goal_dict[goal_key]
    # roll out the trained SAC policy from the agent's current position
    env: WarehouseEnv = WarehouseEnv(render_mode='')
    # renamed from `model` to avoid shadowing the module-level ASR model
    rl_model = SAC.load(rl_model_name)
    obs, info = env.reset(agent_x=agent_pos.x, agent_y=agent_pos.y, goal_x=goal_x, goal_y=goal_y)
    path = []
    while True:
        action, _ = rl_model.predict(obs)
        obs, rewards, done, trunc, info = env.step(action)
        path.append((env.agent_x, env.agent_y))
        if done:
            break
        if trunc:
            # episode timed out: restart from the original position and retry
            obs, info = env.reset(agent_x=agent_pos.x, agent_y=agent_pos.y, goal_x=goal_x, goal_y=goal_y)
            path = []
    # remember where the agent ended up so the next request starts there
    agent_pos.x = path[-1][0]
    agent_pos.y = path[-1][1]
    # create animation
    video_output = create_animation(path)
    return video_output, f'Got it! I went to the goal number {target_input}.'
def load_image_on_start():
    """Return a 700x700 array of uniform [0, 1) noise (placeholder image)."""
    # np.random.random is an alias of random_sample — identical draw to
    # np.random.rand(700, 700), just the tuple-argument spelling.
    return np.random.random((700, 700))
def get_text_request(audio_input):
    """Transcribe a Gradio microphone recording with the wav2vec2 model.

    Args:
        audio_input: (sample_rate, samples) tuple as emitted by
            gr.Microphone. NOTE(review): assumes a mono recording —
            confirm against the Microphone component's configuration.

    Returns:
        The decoded transcription string.
    """
    # sample rate wav2vec2-base-960h expects; previously the literal 16000
    # was repeated in the processor call instead of reusing this constant
    target_sr = 16000
    # --------------------------------------------------------------------------- #
    audio_input_sr, audio_input_np = audio_input
    audio_input_t = torch.tensor(audio_input_np, dtype=torch.float32)
    # resample the mic capture rate down/up to the model's expected rate
    resampler = T.Resample(audio_input_sr, target_sr, dtype=audio_input_t.dtype)
    resampled_audio_input_t: torch.Tensor = resampler(audio_input_t)
    resampled_audio_input_np = resampled_audio_input_t.numpy()
    # --------------------------------------------------------------------------- #
    inputs = processor(resampled_audio_input_np, sampling_rate=target_sr,
                       return_tensors="pt", padding=True)
    # Inference without gradient tracking
    with torch.no_grad():
        logits = model(**inputs).logits
    # Greedy CTC decode: most likely token per frame, then collapse
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    return transcription
def get_target_from_request(request_text):
    """Map a transcribed request to a goal number.

    Matches whole words, case-insensitively — the previous substring test
    fired on false positives such as 'PHONE' (contains 'ONE') or 'BEFORE'
    (contains 'FOR'). Digit spellings ('1'..'4') are accepted too.

    Args:
        request_text: transcription text to scan for a goal keyword.

    Returns:
        The goal number as an int (1-4), or the string 'No goal found.'
        when no keyword is present (kept for compatibility with the UI,
        which displays/forwards the value through a textbox).
    """
    words = set(request_text.upper().split())
    # checked in order, so goal 1 wins if several keywords appear
    keywords = [
        (1, {'ONE', 'FIRST', '1'}),
        (2, {'TWO', 'SECOND', '2'}),
        (3, {'THREE', 'THIRD', '3'}),
        (4, {'FOUR', 'FOURTH', 'FOR', '4'}),
    ]
    for target, names in keywords:
        if words & names:
            return target
    return 'No goal found.'
# main blocks
# UI layout + event wiring. Pipeline on each recording:
#   stop_recording -> get_text_request -> request_text (hidden)
#   request_text.change -> get_target_from_request -> request_target (hidden)
#   request_target.change -> move_agent -> [output_env video, status]
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# Agent Control with Language")
    gr.Markdown('Say the agent where to go')
    # gr.Markdown('Say the agent where to go and what to do')
    with gr.Row():
        with gr.Column(scale=12):
            # visible status line; styled via custom_css (#mytextbox)
            status = gr.Textbox(label='Status:', lines=2, elem_id="mytextbox")
            request_audio = gr.Microphone(editable=False, scale=2)
            # send_btn = gr.Button(value='Send Request')
            # hidden intermediate fields that chain the event pipeline above
            request_text = gr.Textbox(label="Request:", lines=2, interactive=False, visible=False)
            request_target = gr.Textbox(label='Target:', lines=2, visible=False)
        with gr.Column(scale=8):
            # rendered trajectory video; autoplay so each trip plays immediately
            output_env = gr.Video(label="Env:", autoplay=True)
    with gr.Accordion("TODO List", open=False):
        gr.Markdown("""
## PLAN
- [x] to use audio as an input for requests
- [x] to learn a policy for navigation from location to location
- [x] to build an interface that will show the status of the request
- [ ] to incorporate a longer chain of goals; for example, go there and pick the package, then come back
- [ ] to introduce additional learnt capabilities
- [ ] to build more complex environments where the movement is not so straightforward
""")
    # EVENTS:
    # gr.on(triggers=["load"], fn=load_image_on_start, outputs=output_env_image)
    # my_demo.load(fn=load_image_on_start, outputs=output_env_image)
    # on page load: show the idle animation and the initial prompt
    demo.load(fn=create_standing_animation, outputs=output_env)
    demo.load(fn=lambda: 'To which target do you want me to go?', outputs=status)
    # request_audio.stream(fn=get_text_request, inputs=request_audio, outputs=request_text)
    request_audio.stop_recording(fn=get_text_request, inputs=request_audio, outputs=request_text)
    request_text.change(fn=get_target_from_request, inputs=request_text, outputs=request_target)
    request_target.change(fn=move_agent, inputs=request_target, outputs=[output_env, status])
    # second stop_recording handler clears the mic widget for the next request
    request_audio.stop_recording(lambda: None, outputs=request_audio)
# ---------------------------- #
# main
# ---------------------------- #
# NOTE(review): launch() runs unconditionally at import — presumably intended
# for `gradio app.py` hot-reload (see line 1 comment); confirm before
# importing this module from elsewhere.
demo.launch()
# device = "cuda:0" if torch.cuda.is_available() else "cpu"
# torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# model_id = "./models_for_proj/librispeech_asr_dummy"
# model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
# model.to(device)
# processor = AutoProcessor.from_pretrained(model_id)
# asr_pipe = pipeline(
# "automatic-speech-recognition",
# model=model,
# tokenizer=processor.tokenizer,
# feature_extractor=processor.feature_extractor,
# max_new_tokens=128,
# torch_dtype=torch_dtype,
# device=device,
# )
|