fffiloni's picture
Update app.py
8160ec1
import gradio as gr
import random
import time
from PIL import Image
import numpy as np
import soundfile as sf
import imageio
from transformers import OpenAiAgent
from diffusers.utils import export_to_video
# The OpenAiAgent shared by every chat turn. It is None until the user clicks
# "Load agent"; respond() checks for that instead of hitting a NameError.
agent = None


def load_agent(openai_api_key):
    """Create the OpenAI-backed transformers agent and store it in ``agent``.

    Args:
        openai_api_key: Value of the "OpenAI API key" password textbox.
            Surrounding whitespace (common when pasting) is stripped. An empty
            value is passed on as ``None`` rather than ``""``. Presumably this
            lets OpenAiAgent fall back to the ``OPENAI_API_KEY`` environment
            variable; TODO confirm for the pinned transformers version.

    Returns:
        A status string shown in the "Status" textbox.
    """
    global agent
    api_key = (openai_api_key or "").strip() or None
    try:
        agent = OpenAiAgent(model="text-davinci-003", api_key=api_key)
    except ValueError as err:
        # NOTE(review): OpenAiAgent appears to raise ValueError when no key is
        # available at all; surface it in the status box instead of a bare
        # Gradio "Error". Confirm against the transformers version in use.
        return f"Could not load agent: {err}"
    return "Agent is ready"
import torch
def is_audio_tensor(tensor):
    """Heuristically decide whether *tensor* looks like a mono audio waveform.

    A value qualifies when it is a non-empty, 1-D ``torch.float32`` or
    ``torch.float64`` tensor whose samples all lie in ``[-1.0, 1.0]``.
    Sample rate and duration are not (and cannot be) checked here.

    Args:
        tensor: Any object, typically the raw return value of ``agent.chat``.

    Returns:
        True if the heuristics above pass, otherwise False.
    """
    if not isinstance(tensor, torch.Tensor):
        return False
    if tensor.dtype not in (torch.float32, torch.float64):
        return False
    if tensor.dim() != 1:
        return False
    # An empty tensor has no samples; also, max()/min() raise RuntimeError on it.
    if tensor.numel() == 0:
        return False
    # An elementwise check (instead of comparing max()/min()) also rejects NaN
    # samples, because every comparison involving NaN is False.
    return bool(((tensor >= -1.0) & (tensor <= 1.0)).all())
def is_list_of_arrays(lst):
    """Return True if *lst* is a non-empty list made only of numpy arrays.

    ``respond`` treats such a list as video frames and hands it to
    ``export_to_video``. An empty list is rejected because there is nothing
    to export. Only ``list`` qualifies; tuples and other sequences do not.

    Args:
        lst: Any object, typically the raw return value of ``agent.chat``.

    Returns:
        True for a non-empty list of ``np.ndarray`` items, otherwise False.
        The outcome is also printed to stdout for debugging.
    """
    is_frames = (
        isinstance(lst, list)
        and len(lst) > 0
        and all(isinstance(item, np.ndarray) for item in lst)
    )
    if is_frames:
        print('Response is a list of numpy arrays')
    else:
        print('Response is NOT a list of numpy arrays')
    return is_frames
def respond(message, chat_history):
    """Send *message* to the agent and append the exchange to *chat_history*.

    The agent's reply is converted into something ``gr.Chatbot`` can show:

    * a tensor accepted by ``is_audio_tensor`` is written to
      ``speech_converted.wav`` at a 16000 Hz sample rate and shown as a file;
    * a PIL image is saved to ``image.jpg`` and shown as a file;
    * a value accepted by ``is_list_of_arrays`` is exported with
      ``export_to_video`` and the returned path is shown as a file;
    * anything else is appended unchanged (normally text).

    Args:
        message: Text submitted from the input box.
        chat_history: Current chatbot contents as ``(user, bot)`` pairs;
            appended to in place.

    Returns:
        ``("", chat_history)``. The empty string clears the input box.
    """
    current_agent = globals().get("agent")
    if current_agent is None:
        # load_agent() has not run yet; calling .chat would raise and the UI
        # would only show a generic error.
        chat_history.append(
            (message, "Please load the agent with your OpenAI API key first.")
        )
        return "", chat_history

    response = current_agent.chat(message)

    if is_audio_tensor(response):
        # detach().cpu() so GPU tensors and tensors that track gradients can
        # still be converted; a plain .numpy() raises for both.
        sf.write("speech_converted.wav", response.detach().cpu().numpy(), samplerate=16000)
        bot_message = ("speech_converted.wav",)
    elif isinstance(response, Image.Image):
        print("The variable is a PIL image object.")
        # JPEG cannot store alpha/palette modes (e.g. RGBA, P); Pillow raises
        # OSError for them, so convert anything that is not RGB/greyscale.
        image = response if response.mode in ("RGB", "L") else response.convert("RGB")
        image.save('image.jpg')
        bot_message = ('image.jpg',)
    elif is_list_of_arrays(response):
        print('The response is a video array')
        bot_message = (export_to_video(response),)
    else:
        print("The response is simple text")
        bot_message = response

    print(bot_message)
    chat_history.append((message, bot_message))
    time.sleep(1)
    return "", chat_history
# Layout: left column holds the API key, a read-only status box and the
# "Load agent" button; right column holds the chat transcript, the message box
# and a "Clear" button.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            openai_api_key = gr.Textbox(label="OpenAI API key", type="password")
            with gr.Row():
                # The first positional argument of gr.Textbox is its value, so
                # the original gr.Textbox("Status") prefilled the box with the
                # word "Status". Use it as the label instead; only
                # load_agent writes to this box, so users cannot edit it.
                agent_status = gr.Textbox(label="Status", interactive=False)
                load_agent_btn = gr.Button("Load agent")
        with gr.Column():
            chatbot = gr.Chatbot()
            msg = gr.Textbox()
            clear = gr.Button("Clear")
    load_agent_btn.click(load_agent, [openai_api_key], [agent_status])
    # respond returns ("", history): clears the input box and refreshes the chat.
    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    # Returning None empties the chatbot component.
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.launch()