# gradio app.py --watch-dirs app.py
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import tempfile
import torch
# the Auto* classes and `pipeline` are only needed for the commented-out ASR
# backend at the bottom of this file
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torchaudio.transforms as T
from matplotlib.patches import Circle
from stable_baselines3 import SAC
from warehouse_env import WarehouseEnv
from types import SimpleNamespace

# ---------------------------- #
# global variables
# ---------------------------- #

# models
# a model for the automatic-speech-recognition task
# asr_pipe_default = pipeline("automatic-speech-recognition")
save_dir = './models_for_proj/wav2vec2-base-960h'
model = Wav2Vec2ForCTC.from_pretrained(save_dir)
processor = Wav2Vec2Processor.from_pretrained(save_dir)

# env variables
rl_model_name = 'agent_policies/sac_warehouse_r_10_working_v1.zip'
agent_pos = SimpleNamespace(x=50.0, y=50.0)
goal_dict = {
    '1': (20, 20),
    '2': (80, 20),
    '3': (80, 80),
    '4': (20, 80),
}
targets_x = [v[0] for v in goal_dict.values()]
targets_y = [v[1] for v in goal_dict.values()]
r_coverage = 10

custom_css = """
#mytextbox textarea {
    color: blue;
    background-color: #f0f0f0;
    font-weight: bold;
}
"""


# ---------------------------- #
# functions
# ---------------------------- #
def create_standing_animation():
    # a single-frame "animation" of the agent at its current position
    path = [(agent_pos.x, agent_pos.y)]
    return create_animation(path)


def create_animation(path):
    fig, ax = plt.subplots(figsize=(7, 7))

    # agent
    ln1, = ax.plot([path[0][0]], [path[0][1]], marker='o', color='b',
                   alpha=0.5, linewidth=5, markersize=15)

    # targets (positions come from the module-level goal_dict)
    ax.plot(targets_x, targets_y, marker='X', color='orange',
            alpha=0.5, linestyle='none', markersize=15)
    for t_name, (t_x, t_y) in goal_dict.items():
        circle = Circle((t_x, t_y), r_coverage, color='orange', fill=True, alpha=0.3)
        ax.text(t_x, t_y, f"{t_name}", fontsize=20, color='k')
        ax.add_patch(circle)

    def init():
        ax.set_xlim([0, 100])
        ax.set_ylim([0, 100])
        ax.set_title('Warehouse Env', fontweight="bold", size=10)
        return ln1,

    def update(frame):
        # for each frame, update the data stored on the agent artist
        x = [path[frame][0]]
        y = [path[frame][1]]
        ln1.set_data(x, y)
        return ln1,

    ani = animation.FuncAnimation(fig, update, frames=len(path),
                                  init_func=init, blit=True, repeat=False)

    # save to MP4 (requires ffmpeg on the PATH)
    temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    ani.save(temp_video.name, writer='ffmpeg', fps=30)
    plt.close(fig)
    return temp_video.name


def move_agent(target_input: str):
    # the target arrives as text from the hidden Target textbox
    if target_input not in goal_dict:
        return create_standing_animation(), 'Say it again... To which goal do you want me to go?'
    # get the goal location
    goal_x, goal_y = goal_dict[target_input]

    # build the path by rolling out the trained SAC policy
    env = WarehouseEnv(render_mode='')
    rl_model = SAC.load(rl_model_name)  # local name; `model` is the global ASR model
    obs, info = env.reset(agent_x=agent_pos.x, agent_y=agent_pos.y,
                          goal_x=goal_x, goal_y=goal_y)
    path = []
    while True:
        action, _ = rl_model.predict(obs)
        obs, reward, done, trunc, info = env.step(action)
        path.append((env.agent_x, env.agent_y))
        if done:
            break
        if trunc:
            # the episode timed out before reaching the goal: restart the rollout
            obs, info = env.reset(agent_x=agent_pos.x, agent_y=agent_pos.y,
                                  goal_x=goal_x, goal_y=goal_y)
            path = []

    # remember where the agent ended up for the next request
    agent_pos.x = path[-1][0]
    agent_pos.y = path[-1][1]

    # create the animation
    video_output = create_animation(path)
    return video_output, f'Got it! I went to goal number {target_input}.'


def load_image_on_start():
    return np.random.rand(700, 700)


def get_text_request(audio_input):
    # --------------------------------------------------------------------------- #
    # resample the microphone audio to the 16 kHz rate Wav2Vec2 expects
    audio_input_sr, audio_input_np = audio_input
    # fold a possible stereo recording down to mono
    if audio_input_np.ndim > 1:
        audio_input_np = audio_input_np.mean(axis=1)
    audio_input_t = torch.tensor(audio_input_np, dtype=torch.float32)
    target_sr = 16000
    resampler = T.Resample(audio_input_sr, target_sr, dtype=audio_input_t.dtype)
    resampled_audio_input_t: torch.Tensor = resampler(audio_input_t)
    resampled_audio_input_np = resampled_audio_input_t.numpy()
    # --------------------------------------------------------------------------- #
    # result = asr_pipe_default(resampled_audio_input_np)
    # the processor normalizes the waveform (zero mean, unit variance),
    # which absorbs the int16 amplitude scale of the raw recording
    inputs = processor(resampled_audio_input_np, sampling_rate=16000,
                       return_tensors="pt", padding=True)

    # inference
    with torch.no_grad():
        logits = model(**inputs).logits

    # greedy CTC decoding
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    # print("Transcription:", transcription)
    return transcription


def get_target_from_request(request_text):
    # Wav2Vec2 transcriptions are upper case, so the keywords are matched in
    # upper case; 'FOR' catches a common mis-transcription of 'FOUR'
    if any(item in request_text for item in ['ONE', 'FIRST']):
        return 1
    if any(item in request_text for item in ['TWO', 'SECOND']):
        return 2
    if any(item in request_text for item in ['THREE', 'THIRD']):
        return 3
    if any(item in request_text for item in ['FOUR', 'FOURTH', 'FOR']):
        return 4
    return 'No goal found.'
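
# --------------------------------------------------------------------------- #
# a minimal sanity check for the keyword matcher above (a sketch, not wired
# into the app; the example transcriptions are hypothetical, upper-cased to
# mirror Wav2Vec2's output). It also documents a quirk: 'ONE' is tested first
# and matching is by substring, so a phrase containing both 'SECOND' and 'ONE'
# resolves to 1.
def _sanity_check_target_matching():
    assert get_target_from_request('GO TO GOAL ONE') == 1
    assert get_target_from_request('GO TO THE THIRD GOAL') == 3
    assert get_target_from_request('GO FOR IT') == 4           # 'FOR' catches a mis-heard 'FOUR'
    assert get_target_from_request('THE SECOND ONE') == 1      # substring-order quirk, not 2
    assert get_target_from_request('HELLO') == 'No goal found.'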

# ---------------------------- #
# main blocks
# ---------------------------- #
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# Agent Control with Language")
    gr.Markdown('Tell the agent where to go')
    # gr.Markdown('Tell the agent where to go and what to do')
    with gr.Row():
        with gr.Column(scale=12):
            status = gr.Textbox(label='Status:', lines=2, elem_id="mytextbox")
            request_audio = gr.Microphone(editable=False, scale=2)
            # send_btn = gr.Button(value='Send Request')
            request_text = gr.Textbox(label="Request:", lines=2, interactive=False, visible=False)
            request_target = gr.Textbox(label='Target:', lines=2, visible=False)
        with gr.Column(scale=8):
            output_env = gr.Video(label="Env:", autoplay=True)
    with gr.Accordion("TODO List", open=False):
        gr.Markdown("""
        ## PLAN
        - [x] to use audio as an input for requests
        - [x] to learn a policy for navigation from location to location
        - [x] to build an interface that shows the status of the request
        - [ ] to incorporate a longer chain of goals; for example, go there and pick up the package, then come back
        - [ ] to introduce additional learnt capabilities
        - [ ] to build more complex environments where the movement is not so straightforward
        """)

    # EVENTS:
    # gr.on(triggers=["load"], fn=load_image_on_start, outputs=output_env_image)
    # my_demo.load(fn=load_image_on_start, outputs=output_env_image)
    demo.load(fn=create_standing_animation, outputs=output_env)
    demo.load(fn=lambda: 'To which target do you want me to go?', outputs=status)
    # request_audio.stream(fn=get_text_request, inputs=request_audio, outputs=request_text)
    request_audio.stop_recording(fn=get_text_request, inputs=request_audio, outputs=request_text)
    request_text.change(fn=get_target_from_request, inputs=request_text, outputs=request_target)
    request_target.change(fn=move_agent, inputs=request_target, outputs=[output_env, status])
    # clear the microphone so the next request starts fresh
    request_audio.stop_recording(lambda: None, outputs=request_audio)

# ---------------------------- #
# main
# ---------------------------- #
demo.launch()

# --------------------------------------------------------------------------- #
# alternative ASR backend, kept for reference: a seq2seq model served through
# the `pipeline` helper (uses the Auto* imports at the top of the file)
# device = "cuda:0" if torch.cuda.is_available() else "cpu"
# torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# model_id = "./models_for_proj/librispeech_asr_dummy"
# model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype,
#                                                   low_cpu_mem_usage=True, use_safetensors=True)
# model.to(device)
# processor = AutoProcessor.from_pretrained(model_id)
# asr_pipe = pipeline(
#     "automatic-speech-recognition",
#     model=model,
#     tokenizer=processor.tokenizer,
#     feature_extractor=processor.feature_extractor,
#     max_new_tokens=128,
#     torch_dtype=torch_dtype,
#     device=device,
# )
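
# --------------------------------------------------------------------------- #
# if the pipeline backend above were enabled, `get_text_request` could
# delegate to it instead of the manual processor/model calls. A sketch
# (`get_text_request_via_pipeline` is a hypothetical name), assuming the
# standard `automatic-speech-recognition` pipeline output, a dict with a
# 'text' key:
#
# def get_text_request_via_pipeline(audio_input):
#     audio_input_sr, audio_input_np = audio_input
#     audio_input_t = torch.tensor(audio_input_np, dtype=torch.float32)
#     resampler = T.Resample(audio_input_sr, 16000, dtype=audio_input_t.dtype)
#     result = asr_pipe(resampler(audio_input_t).numpy())
#     return result['text']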