# models/app.py — Gradio chat demo (Hugging Face Space)
# Last updated by Stevross, commit cd92318
import json
import os

import gradio as gr
import requests
import torch
from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
)
# Registry of selectable chat models. Each entry pairs a display name with a
# loaded tokenizer and model. NOTE: these load at import time — the active
# OpenChatKit entry downloads a 20B-parameter checkpoint and moves it to
# cuda:0, so this module requires a CUDA GPU with substantial memory.
# The commented-out entries are alternates that can be re-enabled.
model_specs = [
    # {
    # "name": "distilgpt2",
    # "tokenizer": GPT2Tokenizer.from_pretrained("distilgpt2"),
    # "model": GPT2LMHeadModel.from_pretrained("distilgpt2"),
    # },
    # {
    # "name": "openai-gpt",
    # "tokenizer": AutoTokenizer.from_pretrained("openai-gpt"),
    # "model": AutoModelForCausalLM.from_pretrained("openai-gpt"),
    # },
    {
        "name": "OpenChatKit",
        "tokenizer": AutoTokenizer.from_pretrained("togethercomputer/GPT-NeoXT-Chat-Base-20B"),
        # fp16 weights to halve GPU memory; pinned to the first CUDA device.
        "model": AutoModelForCausalLM.from_pretrained(
            "togethercomputer/GPT-NeoXT-Chat-Base-20B", torch_dtype=torch.float16
        ).to("cuda:0"),
    },
    # {
    # "name": "flan-t5-xl",
    # "tokenizer": AutoTokenizer.from_pretrained("google/flan-t5-xl"),
    # "model": AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xl"),
    # },
    # {
    # "name": "LLama",
    # "tokenizer": AutoTokenizer.from_pretrained("decapoda-research/llama-7b-hf"),
    # "model": AutoModelForCausalLM.from_pretrained("decapoda-research/llama-7b-hf"),
    # },
]
def generate_response(model_name, input_text, speak_output):
    """Generate a chat response with the selected model.

    Args:
        model_name: Display name matching one entry in ``model_specs``.
        input_text: Prompt text to feed the model.
        speak_output: If True, also synthesize the response to
            ``response_audio.mp3`` via the ElevenLabs TTS API.

    Returns:
        The generated response text, or an error message when the model
        name is unknown.
    """
    selected_model = next(
        (spec for spec in model_specs if spec["name"] == model_name), None
    )
    if not selected_model:
        return "Invalid model selected."
    tokenizer = selected_model["tokenizer"]
    model = selected_model["model"]
    # BUG FIX: the active model lives on cuda:0, so the input tensor must be
    # moved to the model's device — otherwise generate() raises a
    # device-mismatch RuntimeError.
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
    if model_name in ["flan-t5-xl"]:
        # Seq2Seq models: the decoder produces the output sequence on its own.
        output_ids = model.generate(input_ids)
    else:
        # CausalLM models: allow up to 50 new tokens beyond the prompt.
        output_ids = model.generate(input_ids, max_length=input_ids.shape[1] + 50)
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    if speak_output:
        ELEVENLABS_API_URL = "https://api.elevenlabs.io/v1/text-to-speech/21m00Tcm4TlvDq8ikWAM"
        # SECURITY: the API key was previously hard-coded here (a leaked
        # secret); it is now read from the environment. Set ELEVENLABS_API_KEY
        # in the Space/host configuration.
        HEADERS = {
            "accept": "audio/mpeg",
            "xi-api-key": os.environ.get("ELEVENLABS_API_KEY", ""),
            "Content-Type": "application/json",
        }
        data = {
            "text": response,
            "voice_settings": {"stability": 0, "similarity_boost": 0},
        }
        # timeout prevents a stalled TTS call from hanging the request thread.
        response_audio = requests.post(
            ELEVENLABS_API_URL, headers=HEADERS, json=data, timeout=30
        )
        # Only persist the audio on success; otherwise we'd save an error
        # JSON body with an .mp3 extension. TTS failure is best-effort —
        # the text response is still returned.
        if response_audio.ok:
            with open("response_audio.mp3", "wb") as f:
                f.write(response_audio.content)
    return response
if __name__ == "__main__":
    model_names = [spec["name"] for spec in model_specs]
    # BUG FIX: the gr.inputs.* namespace was removed in Gradio 3.x; the
    # component classes now live at the top level of the gradio package.
    model_dropdown = gr.Dropdown(choices=model_names, label="Select Model")
    text_input = gr.Textbox(lines=5, label="Input Text")
    speak_output = gr.Checkbox(label="Speak output with ElevenLabs API")
    # Wire the three inputs to generate_response and serve the UI.
    gr.Interface(
        fn=generate_response,
        inputs=[model_dropdown, text_input, speak_output],
        outputs="text",
        title="Chat Model Selection",
        description="Select a chat model and enter text to generate a response.",
    ).launch()