# Scraped HuggingFace page header (not Python) — kept as comments:
# ArseniyPerchik's picture
# more
# 311bb5a
# gradio app.py --watch-dirs app.py
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import tempfile
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torchaudio
import torchaudio.transforms as T
from matplotlib.patches import Circle
from stable_baselines3 import SAC
from warehouse_env import WarehouseEnv
from types import SimpleNamespace
# ---------------------------- #
# global variables
# ---------------------------- #
# models
# a model for the automatic-speech-recognition task
# asr_pipe_default = pipeline("automatic-speech-recognition")
save_dir = './models_for_proj/wav2vec2-base-960h'
model = Wav2Vec2ForCTC.from_pretrained(save_dir)
processor = Wav2Vec2Processor.from_pretrained(save_dir)
# env variables
rl_model_name = 'agent_policies/sac_warehouse_r_10_working_v1.zip'
# mutable agent position, shared across requests
# agent_pos = {'x': 50.0, 'y': 50.0}
agent_pos = SimpleNamespace(x=50.0, y=50.0)
# goal id (as spoken/typed text) -> (x, y) location on the 100x100 map
goal_dict = {
    '1': (20, 20),
    '2': (80, 20),
    '3': (80, 80),
    '4': (20, 80),
}
# parallel coordinate lists of the goals, for plotting
targets_x = [xy[0] for xy in goal_dict.values()]
targets_y = [xy[1] for xy in goal_dict.values()]
r_coverage = 10
custom_css = """
#mytextbox textarea {
color: blue;
background-color: #f0f0f0;
font-weight: bold;
}
"""
# ---------------------------- #
# functions
# ---------------------------- #
def create_standing_animation():
    """Render a one-frame video of the agent idling at its current position."""
    here = (agent_pos.x, agent_pos.y)
    return create_animation([here])
def create_animation(path):
    """Render *path* — a list of (x, y) agent positions — as an MP4 clip.

    The scene is the fixed 100x100 warehouse: the agent is a blue dot moved
    frame by frame along *path*; the goals from ``goal_dict`` are drawn once
    as orange crosses with translucent coverage circles of radius
    ``r_coverage``.

    Returns the filename of a temporary .mp4 file (the caller owns it and is
    responsible for deleting it).
    """
    fig, ax = plt.subplots(figsize=(7, 7))
    # agent: a single blue marker, repositioned in update()
    ln1, = ax.plot([path[0][0]], [path[0][1]], marker='o', color='b',
                   alpha=0.5, linewidth=5, markersize=15)
    # targets: reuse the module-level coordinate lists instead of rebuilding
    # identical local copies from goal_dict
    ax.plot(targets_x, targets_y, marker='X', color='orange', alpha=0.5,
            linestyle='none', markersize=15)
    for t_name, (t_x, t_y) in goal_dict.items():
        circle = Circle((t_x, t_y), r_coverage, color='orange', fill=True, alpha=0.3)
        ax.text(t_x, t_y, f"{t_name}", fontsize=20, color='k')
        ax.add_patch(circle)
    # plt.tight_layout()

    def init():
        ax.set_xlim([0, 100])
        ax.set_ylim([0, 100])
        ax.set_title('Warehouse Env', fontweight="bold", size=10)
        return ln1,

    def update(frame):
        # for each frame, update the data stored on each artist.
        x = [path[frame][0]]
        y = [path[frame][1]]
        ln1.set_data(x, y)
        return ln1,

    ani = animation.FuncAnimation(fig, update, frames=len(path),
                                  init_func=init, blit=True, repeat=False)
    # Save to MP4. Close the temp-file handle first: NamedTemporaryFile keeps
    # it open, which leaks the descriptor and prevents ffmpeg from writing
    # the file on Windows.
    temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    temp_video.close()
    ani.save(temp_video.name, writer='ffmpeg', fps=30)
    plt.close(fig)
    return temp_video.name
def move_agent(target_input: str):
    """Drive the agent to the requested goal and render the trip as video.

    Parameters
    ----------
    target_input : str
        Goal id as text ('1'..'4').  The value arrives from a Gradio Textbox,
        so it is a string; the membership test against ``goal_dict``'s string
        keys only succeeds for str (the old ``int`` annotation was wrong).

    Returns
    -------
    tuple
        (video_file_path, status_message) for the (output_env, status) outputs.
    """
    if target_input not in goal_dict:
        # unrecognized goal: stay in place and ask the user to repeat
        return create_standing_animation(), 'Say it again.. To what goal you want me to go?..'
    # get goal locations:
    goal_x, goal_y = goal_dict[target_input]
    # build the path
    env: WarehouseEnv = WarehouseEnv(render_mode='')
    # NOTE(review): local `model` (SAC policy) intentionally shadows the
    # module-level ASR `model` inside this function
    model = SAC.load(rl_model_name)
    obs, info = env.reset(agent_x=agent_pos.x, agent_y=agent_pos.y, goal_x=goal_x, goal_y=goal_y)
    path = []
    while True:
        action, _ = model.predict(obs)
        obs, rewards, done, trunc, info = env.step(action)
        path.append((env.agent_x, env.agent_y))
        if done:
            break
        if trunc:
            # episode truncated before reaching the goal: restart from the
            # original position and discard the partial path.
            # NOTE(review): if the goal is unreachable this loops forever —
            # consider a retry cap.
            obs, info = env.reset(agent_x=agent_pos.x, agent_y=agent_pos.y, goal_x=goal_x, goal_y=goal_y)
            path = []
    # persist the final position so the next request starts from here
    agent_pos.x = path[-1][0]
    agent_pos.y = path[-1][1]
    # create animation
    video_output = create_animation(path)
    return video_output, f'Got it! I went to the goal number {target_input}.'
def load_image_on_start():
    """Return a 700x700 array of uniform [0, 1) noise (placeholder image)."""
    # random_sample is the same global-RNG routine that np.random.rand wraps
    return np.random.random_sample((700, 700))
def get_text_request(audio_input):
    """Transcribe a recorded audio clip to text with the Wav2Vec2 model.

    Parameters
    ----------
    audio_input : tuple
        (sample_rate, samples) as delivered by ``gr.Microphone``; samples is
        a numpy array.  # assumes a mono signal — TODO confirm for stereo mics

    Returns
    -------
    str
        Decoded transcription (presumably uppercase, which is what
        ``get_target_from_request`` matches against — verify with the model).
    """
    # --------------------------------------------------------------------------- #
    # resample whatever rate the microphone produced down/up to 16 kHz,
    # the rate the Wav2Vec2 checkpoint expects
    audio_input_sr, audio_input_np = audio_input
    audio_input_t = torch.tensor(audio_input_np, dtype=torch.float32)
    target_sr = 16000
    resampler = T.Resample(audio_input_sr, target_sr, dtype=audio_input_t.dtype)
    resampled_audio_input_t: torch.Tensor = resampler(audio_input_t)
    resampled_audio_input_np = resampled_audio_input_t.numpy()
    # --------------------------------------------------------------------------- #
    # result = asr_pipe_default(resampled_audio_input_np)
    inputs = processor(resampled_audio_input_np, sampling_rate=16000, return_tensors="pt", padding=True)
    # Inference
    with torch.no_grad():
        logits = model(**inputs).logits
    # Decode: greedy argmax per frame, then CTC collapse in processor.decode
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    # print("Transcription:", transcription)
    return transcription
def get_target_from_request(request_text):
    """Map an (uppercase) ASR transcription to a goal number 1-4.

    Matches whole words only: the old substring test fired on unrelated
    words ('GONE' contains 'ONE', 'BEFORE' contains 'FOR').  'FOR' is kept
    as an alias for 4 because the ASR commonly hears "four" as "for".
    Checks goals in order 1..4, so the lowest matching goal wins (same
    priority as before).

    Returns the goal number as an int, or the string 'No goal found.' when
    no keyword is present (mixed return type preserved for the UI wiring).
    """
    words = set(request_text.split())
    keyword_sets = (
        ({'ONE', 'FIRST'}, 1),
        ({'TWO', 'SECOND'}, 2),
        ({'THREE', 'THIRD'}, 3),
        ({'FOUR', 'FOURTH', 'FOR'}, 4),
    )
    for keywords, goal in keyword_sets:
        if words & keywords:
            return goal
    return 'No goal found.'
# main blocks
# UI layout + event wiring: audio request -> transcription -> goal id -> agent video
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# Agent Control with Language")
    gr.Markdown('Say the agent where to go')
    # gr.Markdown('Say the agent where to go and what to do')
    with gr.Row():
        with gr.Column(scale=12):
            status = gr.Textbox(label='Status:', lines=2, elem_id="mytextbox")
            request_audio = gr.Microphone(editable=False, scale=2)
            # send_btn = gr.Button(value='Send Request')
            # hidden intermediate stages of the request pipeline; their
            # .change events drive the next stage
            request_text = gr.Textbox(label="Request:", lines=2, interactive=False, visible=False)
            request_target = gr.Textbox(label='Target:', lines=2, visible=False)
        with gr.Column(scale=8):
            output_env = gr.Video(label="Env:", autoplay=True)
    with gr.Accordion("TODO List", open=False):
        gr.Markdown("""
## PLAN
- [x] to use audio as an input for requests
- [x] to learn a policy for navigation from location to location
- [x] to build an interface that will show the status of the request
- [ ] to incorporate a longer chain of goals; for example, go there and pick the package, then come back
- [ ] to introduce additional learnt capabilities
- [ ] to build more complex environments where the movement is not so straightforward
""")
    # EVENTS:
    # gr.on(triggers=["load"], fn=load_image_on_start, outputs=output_env_image)
    # my_demo.load(fn=load_image_on_start, outputs=output_env_image)
    # on page load: show the idle agent and the initial prompt
    demo.load(fn=create_standing_animation, outputs=output_env)
    demo.load(fn=lambda: 'To which target do you want me to go?', outputs=status)
    # request_audio.stream(fn=get_text_request, inputs=request_audio, outputs=request_text)
    # chained pipeline: stop_recording -> transcription; text change -> goal id;
    # goal change -> agent moves and the video/status outputs update
    request_audio.stop_recording(fn=get_text_request, inputs=request_audio, outputs=request_text)
    request_text.change(fn=get_target_from_request, inputs=request_text, outputs=request_target)
    request_target.change(fn=move_agent, inputs=request_target, outputs=[output_env, status])
    # clear the recorder so the next utterance starts from an empty widget
    request_audio.stop_recording(lambda: None, outputs=request_audio)
# ---------------------------- #
# main
# ---------------------------- #
demo.launch()
# device = "cuda:0" if torch.cuda.is_available() else "cpu"
# torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# model_id = "./models_for_proj/librispeech_asr_dummy"
# model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
# model.to(device)
# processor = AutoProcessor.from_pretrained(model_id)
# asr_pipe = pipeline(
# "automatic-speech-recognition",
# model=model,
# tokenizer=processor.tokenizer,
# feature_extractor=processor.feature_extractor,
# max_new_tokens=128,
# torch_dtype=torch_dtype,
# device=device,
# )