Stevross committed on
Commit
fbd485b
·
1 Parent(s): e4c945b

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import json
4
+ from transformers import (
5
+ GPT2Tokenizer,
6
+ GPT2LMHeadModel,
7
+ AutoTokenizer,
8
+ AutoModelForCausalLM,
9
+ AutoModelForSeq2SeqLM,
10
+ )
11
+ import torch
12
+
13
+ model_specs = [
14
+ {
15
+ "name": "distilgpt2",
16
+ "tokenizer": GPT2Tokenizer.from_pretrained("distilgpt2"),
17
+ "model": GPT2LMHeadModel.from_pretrained("distilgpt2"),
18
+ },
19
+ {
20
+ "name": "openai-gpt",
21
+ "tokenizer": AutoTokenizer.from_pretrained("openai-gpt"),
22
+ "model": AutoModelForCausalLM.from_pretrained("openai-gpt"),
23
+ },
24
+ {
25
+ "name": "OpenChatKit",
26
+ "tokenizer": AutoTokenizer.from_pretrained("togethercomputer/GPT-NeoXT-Chat-Base-20B"),
27
+ "model": AutoModelForCausalLM.from_pretrained(
28
+ "togethercomputer/GPT-NeoXT-Chat-Base-20B", torch_dtype=torch.float16
29
+ ).to("cuda:0"),
30
+ },
31
+ {
32
+ "name": "flan-t5-xl",
33
+ "tokenizer": AutoTokenizer.from_pretrained("google/flan-t5-xl"),
34
+ "model": AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-xl"),
35
+ },
36
+ {
37
+ "name": "LLama",
38
+ "tokenizer": AutoTokenizer.from_pretrained("decapoda-research/llama-7b-hf"),
39
+ "model": AutoModelForCausalLM.from_pretrained("decapoda-research/llama-7b-hf"),
40
+ },
41
+ ]
42
+
43
+ def generate_response(model_name, input_text, speak_output):
44
+ selected_model = next(
45
+ (spec for spec in model_specs if spec["name"] == model_name), None
46
+ )
47
+ if not selected_model:
48
+ return "Invalid model selected."
49
+
50
+ tokenizer = selected_model["tokenizer"]
51
+ model = selected_model["model"]
52
+
53
+ if model_name in ["flan-t5-xl"]:
54
+ # For Seq2Seq models
55
+ input_ids = tokenizer.encode(input_text, return_tensors="pt")
56
+ output_ids = model.generate(input_ids)
57
+ response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
58
+ else:
59
+ # For CausalLM models
60
+ input_ids = tokenizer.encode(input_text, return_tensors="pt")
61
+ output_ids = model.generate(input_ids, max_length=input_ids.shape[1] + 50)
62
+ response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
63
+
64
+ if speak_output:
65
+ ELEVENLABS_API_URL = "https://api.elevenlabs.io/v1/text-to-speech/21m00Tcm4TlvDq8ikWAM"
66
+ HEADERS = {
67
+ "accept": "audio/mpeg",
68
+ "xi-api-key": "9360e4509988dc28ff03a5d43cb6941b",
69
+ "Content-Type": "application/json",
70
+ }
71
+ data = {
72
+ "text": response,
73
+ "voice_settings": {"stability": 0, "similarity_boost": 0},
74
+ }
75
+ response_audio = requests.post(ELEVENLABS_API_URL, headers=HEADERS, data=json.dumps(data))
76
+
77
+ with open("response_audio.mp3", "wb") as f:
78
+ f.write(response_audio.content)
79
+
80
+ return response
81
+
82
+ if __name__ == "__main__":
83
+ model_names = [spec["name"] for spec in model_specs]
84
+ model_dropdown = gr.inputs.Dropdown(model_names, label="Select Model")
85
+ text_input = gr.inputs.Textbox(lines=5, label="Input Text")
86
+ speak_output = gr.inputs.Checkbox(label="Speak output with ElevenLabs API")
87
+
88
+ gr.Interface(
89
+ generate_response,
90
+ inputs=[model_dropdown, text_input, speak_output],
91
+ outputs="text",
92
+ title="Chat Model Selection",
93
+ description="Select a chat model and enter text to generate a response.",
94
+ ).launch()