# gradio app.py --watch-dirs app.py
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import tempfile
import torch
# AutoModelForSpeechSeq2Seq, AutoProcessor and pipeline are only used by the
# commented-out alternative ASR setup at the bottom of this file
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torchaudio
import torchaudio.transforms as T
from matplotlib.patches import Circle
from stable_baselines3 import SAC
from warehouse_env import WarehouseEnv
from types import SimpleNamespace
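# warehouse_env.py and agent_policies/ are local project files: a custom Gymnasium
# environment and a pre-trained SAC policy (both assumed to ship with the Space)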
# ---------------------------- #
# global variables
# ---------------------------- #
# models
# a model for the automatic-speech-recognition task
# asr_pipe_default = pipeline("automatic-speech-recognition")
save_dir = './models_for_proj/wav2vec2-base-960h'
model = Wav2Vec2ForCTC.from_pretrained(save_dir)
processor = Wav2Vec2Processor.from_pretrained(save_dir)
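# save_dir is expected to hold a local snapshot of the ASR checkpoint (config,
# feature-extractor/tokenizer files and weights), presumably facebook/wav2vec2-base-960h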
# env variables
rl_model_name = 'agent_policies/sac_warehouse_r_10_working_v1.zip'
# agent_pos = {'x': 50.0, 'y': 50.0}
# SimpleNamespace gives attribute access (agent_pos.x); move_agent mutates it,
# so the agent's position persists between requests
agent_pos = SimpleNamespace(**{'x': 50.0, 'y': 50.0})
goal_dict = {
    '1': (20, 20),
    '2': (80, 20),
    '3': (80, 80),
    '4': (20, 80),
}
targets_x, targets_y = [], []
for k, v in goal_dict.items():
    targets_x.append(v[0])
    targets_y.append(v[1])
# radius of each goal's coverage circle (the 'r_10' in the policy filename suggests
# the agent was trained to stop within this radius)
r_coverage = 10
custom_css = """
#mytextbox textarea {
    color: blue;
    background-color: #f0f0f0;
    font-weight: bold;
}
"""
# ---------------------------- #
# functions
# ---------------------------- #
def create_standing_animation():
    # single-frame "animation": the agent standing at its current position
    path = [(agent_pos.x, agent_pos.y)]
    return create_animation(path)
def create_animation(path):
    # path = [(i,i) for i in range(90)]
    fig, ax = plt.subplots(figsize=(7, 7))
    # agent
    ln1, = plt.plot([path[0][0]], [path[0][1]], marker='o', color='b', alpha=0.5, linewidth=5, markersize=15)
    # targets (module-level targets_x / targets_y, built from goal_dict above)
    ln2, = plt.plot(targets_x, targets_y, marker='X', color='orange', alpha=0.5, linestyle='none', markersize=15)
    for t_name, (t_x, t_y) in goal_dict.items():
        circle = Circle((t_x, t_y), r_coverage, color='orange', fill=True, alpha=0.3)
        plt.text(t_x, t_y, f"{t_name}", fontsize=20, color='k')
        ax.add_patch(circle)
    # plt.tight_layout()

    def init():
        ax.set_xlim([0, 100])
        ax.set_ylim([0, 100])
        ax.set_title('Warehouse Env', fontweight="bold", size=10)
        return ln1,

    def update(frame):
        # for each frame, update the data stored on each artist
        x = [path[frame][0]]
        y = [path[frame][1]]
        ln1.set_data(x, y)
        return ln1,

    ani = animation.FuncAnimation(fig, update, frames=len(path),
                                  init_func=init, blit=True, repeat=False)
    # plt.show()
    # save to MP4
    temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
    ani.save(temp_video.name, writer='ffmpeg', fps=30)
    plt.close(fig)
    return temp_video.name
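# note: writer='ffmpeg' requires ffmpeg on the machine; on Hugging Face Spaces that
# usually means listing it in packages.txt, and a missing binary fails at save time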
def move_agent(target_input: str):
    # target_input arrives as the Textbox value, i.e. a string ('1'..'4' or 'No goal found.')
    if target_input not in goal_dict:
        return create_standing_animation(), 'Say it again.. To which goal do you want me to go?..'
    # get goal location:
    goal_x, goal_y = goal_dict[target_input]
    # roll out the trained policy to build the path
    env: WarehouseEnv = WarehouseEnv(render_mode='')
    rl_model = SAC.load(rl_model_name)  # named rl_model to avoid shadowing the global ASR model
    obs, info = env.reset(agent_x=agent_pos.x, agent_y=agent_pos.y, goal_x=goal_x, goal_y=goal_y)
    path = []
    while True:
        action, _ = rl_model.predict(obs)
        obs, reward, done, trunc, info = env.step(action)
        path.append((env.agent_x, env.agent_y))
        if done:
            break
        if trunc:
            # episode timed out before reaching the goal: reset and try again
            obs, info = env.reset(agent_x=agent_pos.x, agent_y=agent_pos.y, goal_x=goal_x, goal_y=goal_y)
            path = []
    # remember the final position for the next request
    agent_pos.x = path[-1][0]
    agent_pos.y = path[-1][1]
    # create animation
    video_output = create_animation(path)
    return video_output, f'Got it! I went to goal number {target_input}.'
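# WarehouseEnv lives in warehouse_env.py (not shown here); from its use above it is
# assumed to expose reset(agent_x, agent_y, goal_x, goal_y) -> (obs, info),
# a gymnasium-style step(action) -> (obs, reward, done, trunc, info), and
# agent_x / agent_y attributes holding the agent's current position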
def load_image_on_start():
    # placeholder image; only used by the commented-out image-based events below
    return np.random.rand(700, 700)
def get_text_request(audio_input):
    # --------------------------------------------------------------------------- #
    # gr.Microphone returns (sample_rate, numpy array); samples are int16 by default,
    # so collapse stereo to mono and scale to [-1, 1] floats before feeding the model
    audio_input_sr, audio_input_np = audio_input
    audio_input_np = audio_input_np.astype(np.float32)
    if audio_input_np.ndim > 1:
        audio_input_np = audio_input_np.mean(axis=1)
    audio_input_np /= max(np.abs(audio_input_np).max(), 1e-9)
    audio_input_t = torch.tensor(audio_input_np, dtype=torch.float32)
    # wav2vec2 expects 16 kHz input
    target_sr = 16000
    resampler = T.Resample(audio_input_sr, target_sr, dtype=audio_input_t.dtype)
    resampled_audio_input_t: torch.Tensor = resampler(audio_input_t)
    resampled_audio_input_np = resampled_audio_input_t.numpy()
    # --------------------------------------------------------------------------- #
    # result = asr_pipe_default(resampled_audio_input_np)
    inputs = processor(resampled_audio_input_np, sampling_rate=target_sr, return_tensors="pt", padding=True)
    # inference
    with torch.no_grad():
        logits = model(**inputs).logits
    # decode: greedy argmax over the CTC logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.decode(predicted_ids[0])
    # print("Transcription:", transcription)
    return transcription
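# quick local check (commented out): feed one second of synthetic int16 audio
# through the full ASR path; the 440 Hz tone and 48 kHz rate are arbitrary
# _sr = 48000
# _tone = (0.1 * np.sin(2 * np.pi * 440 * np.arange(_sr) / _sr) * 32767).astype(np.int16)
# print(get_text_request((_sr, _tone)))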
def get_target_from_request(request_text):
    # wav2vec2-base-960h transcriptions are uppercase; match whole words so that
    # e.g. 'DONE' does not trigger 'ONE' ('FOR' is kept as a common mishearing of 'FOUR')
    words = request_text.split()
    if any(item in words for item in ['ONE', 'FIRST']):
        return 1
    if any(item in words for item in ['TWO', 'SECOND']):
        return 2
    if any(item in words for item in ['THREE', 'THIRD']):
        return 3
    if any(item in words for item in ['FOUR', 'FOURTH', 'FOR']):
        return 4
    return 'No goal found.'
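# examples: 'GO TO TARGET NUMBER ONE' -> 1, 'PLEASE GO FOR IT' -> 4 ('FOR' heard as 'FOUR'),
# 'HELLO THERE' -> 'No goal found.'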
# main blocks
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# Agent Control with Language")
    gr.Markdown('Tell the agent where to go')
    # gr.Markdown('Tell the agent where to go and what to do')
    with gr.Row():
        with gr.Column(scale=12):
            status = gr.Textbox(label='Status:', lines=2, elem_id="mytextbox")
            request_audio = gr.Microphone(editable=False, scale=2)
            # send_btn = gr.Button(value='Send Request')
            request_text = gr.Textbox(label="Request:", lines=2, interactive=False, visible=False)
            request_target = gr.Textbox(label='Target:', lines=2, visible=False)
        with gr.Column(scale=8):
            output_env = gr.Video(label="Env:", autoplay=True)
    with gr.Accordion("TODO List", open=False):
        gr.Markdown("""
        ## PLAN
        - [x] use audio as an input for requests
        - [x] learn a policy for navigation from location to location
        - [x] build an interface that shows the status of the request
        - [ ] incorporate a longer chain of goals; for example, go there and pick up the package, then come back
        - [ ] introduce additional learnt capabilities
        - [ ] build more complex environments where the movement is not so straightforward
        """)
    # EVENTS:
    # the pipeline is chained through hidden textboxes:
    # stop_recording -> transcription -> target extraction -> agent rollout + video
    # gr.on(triggers=["load"], fn=load_image_on_start, outputs=output_env_image)
    # my_demo.load(fn=load_image_on_start, outputs=output_env_image)
    demo.load(fn=create_standing_animation, outputs=output_env)
    demo.load(fn=lambda: 'To which target do you want me to go?', outputs=status)
    # request_audio.stream(fn=get_text_request, inputs=request_audio, outputs=request_text)
    request_audio.stop_recording(fn=get_text_request, inputs=request_audio, outputs=request_text)
    request_text.change(fn=get_target_from_request, inputs=request_text, outputs=request_target)
    # note: .change fires only when the value actually changes, so asking for the
    # same goal twice in a row will not re-trigger the agent
    request_target.change(fn=move_agent, inputs=request_target, outputs=[output_env, status])
    # clear the microphone once the recording has been consumed
    request_audio.stop_recording(lambda: None, outputs=request_audio)
# ---------------------------- #
# main
# ---------------------------- #
demo.launch()
# --------------------------------------------------------------------------- #
# alternative ASR setup kept for reference: a seq2seq model served through the HF pipeline
# device = "cuda:0" if torch.cuda.is_available() else "cpu"
# torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
# model_id = "./models_for_proj/librispeech_asr_dummy"
# model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True)
# model.to(device)
# processor = AutoProcessor.from_pretrained(model_id)
# asr_pipe = pipeline(
#     "automatic-speech-recognition",
#     model=model,
#     tokenizer=processor.tokenizer,
#     feature_extractor=processor.feature_extractor,
#     max_new_tokens=128,
#     torch_dtype=torch_dtype,
#     device=device,
# )