# Gradio demo: select a chat model, generate a reply, and optionally
# speak it aloud through the ElevenLabs text-to-speech API.
import json
import os

import gradio as gr
import requests
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    GPT2LMHeadModel,
    GPT2Tokenizer,
)
# Registry of selectable chat models. Each active entry eagerly downloads
# and instantiates its tokenizer and model at import time, so only one
# large model is kept enabled; the alternatives are left commented out.
model_specs = [
    # {
    # "name": "distilgpt2",
    # "tokenizer": GPT2Tokenizer.from_pretrained("distilgpt2"),
    # "model": GPT2LMHeadModel.from_pretrained("distilgpt2"),
    # },
    # {
    # "name": "openai-gpt",
    # "tokenizer": AutoTokenizer.from_pretrained("openai-gpt"),
    # "model": AutoModelForCausalLM.from_pretrained("openai-gpt"),
    # },
    {
        "name": "OpenChatKit",
        "tokenizer": AutoTokenizer.from_pretrained("togethercomputer/GPT-NeoXT-Chat-Base-20B"),
        # NOTE(review): a ~20B-parameter model loaded in float16 directly onto
        # cuda:0 — presumably needs a single GPU with ~40+ GB memory; confirm
        # against the deployment hardware.
        "model": AutoModelForCausalLM.from_pretrained(
            "togethercomputer/GPT-NeoXT-Chat-Base-20B", torch_dtype=torch.float16
        ).to("cuda:0"),
    },
    # {
    # "name": "flan-t5-xl",
    # "tokenizer": AutoTokenizer.from_pretrained("google/flan-t5-xl"),
    # "model": AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xl"),
    # },
    # {
    # "name": "LLama",
    # "tokenizer": AutoTokenizer.from_pretrained("decapoda-research/llama-7b-hf"),
    # "model": AutoModelForCausalLM.from_pretrained("decapoda-research/llama-7b-hf"),
    # },
]
def generate_response(model_name, input_text, speak_output):
    """Generate a text reply from the selected model, optionally speaking it.

    Args:
        model_name: Display name matching one entry in ``model_specs``.
        input_text: Prompt text entered by the user.
        speak_output: When truthy, also synthesize the reply to
            ``response_audio.mp3`` via the ElevenLabs API.

    Returns:
        The decoded model response, or an error string for an unknown model.
    """
    selected_model = next(
        (spec for spec in model_specs if spec["name"] == model_name), None
    )
    if not selected_model:
        return "Invalid model selected."

    tokenizer = selected_model["tokenizer"]
    model = selected_model["model"]

    # Encode once, and move the tensor onto the model's device. The original
    # left input_ids on the CPU while the OpenChatKit model lives on cuda:0,
    # which raises a device-mismatch error inside generate().
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)

    if model_name in ["flan-t5-xl"]:
        # Seq2Seq models: use the library's default generation length.
        output_ids = model.generate(input_ids)
    else:
        # Causal LMs: allow up to 50 new tokens beyond the prompt.
        output_ids = model.generate(input_ids, max_length=input_ids.shape[1] + 50)
    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    if speak_output:
        _speak(response)
    return response


def _speak(text):
    """Synthesize *text* to ``response_audio.mp3`` via the ElevenLabs TTS API."""
    # SECURITY: the fallback below is an API key hard-coded in source control.
    # It should be rotated and supplied only via the ELEVENLABS_API_KEY
    # environment variable (kept as a fallback here for backward compatibility).
    api_key = os.environ.get("ELEVENLABS_API_KEY", "9360e4509988dc28ff03a5d43cb6941b")
    url = "https://api.elevenlabs.io/v1/text-to-speech/21m00Tcm4TlvDq8ikWAM"
    headers = {
        "accept": "audio/mpeg",
        "xi-api-key": api_key,
        "Content-Type": "application/json",
    }
    data = {
        "text": text,
        "voice_settings": {"stability": 0, "similarity_boost": 0},
    }
    # json= serializes and sets Content-Type in one step; the original
    # hand-rolled json.dumps with data=.
    audio = requests.post(url, headers=headers, json=data)
    # Only persist the payload on success; the original wrote the response
    # body unconditionally, so an HTTP error page ended up inside the .mp3.
    if audio.ok:
        with open("response_audio.mp3", "wb") as f:
            f.write(audio.content)
if __name__ == "__main__":
    model_names = [spec["name"] for spec in model_specs]
    # The gr.inputs.* namespace was deprecated in Gradio 3.x and removed in
    # 4.x; the top-level component classes are the supported spelling.
    model_dropdown = gr.Dropdown(choices=model_names, label="Select Model")
    text_input = gr.Textbox(lines=5, label="Input Text")
    speak_output = gr.Checkbox(label="Speak output with ElevenLabs API")
    gr.Interface(
        fn=generate_response,
        inputs=[model_dropdown, text_input, speak_output],
        outputs="text",
        title="Chat Model Selection",
        description="Select a chat model and enter text to generate a response.",
    ).launch()