Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,69 +6,31 @@ import time
|
|
| 6 |
import torch
|
| 7 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
| 8 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
from threading import Thread
|
| 10 |
os.environ["COQUI_TOS_AGREED"] = "1"
|
| 11 |
os.environ["TRAINER_TELEMETRY"]= "0"
|
| 12 |
# Constants
|
| 13 |
-
|
| 14 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 15 |
-
|
| 16 |
# Set the device to GPU or CPU
|
| 17 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 18 |
|
| 19 |
-
# Load the
|
| 20 |
-
tokenizer = AutoTokenizer.from_pretrained(
|
| 21 |
-
model =
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
<style>
|
| 31 |
-
h1 {
|
| 32 |
-
text-align: center;
|
| 33 |
-
}
|
| 34 |
-
</style>
|
| 35 |
-
"""
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
# Function to handle chat
|
| 39 |
-
def stream_chat(message, history, temperature=0.3, max_new_tokens=1024, top_p=1.0, top_k=20, penalty=1.2):
|
| 40 |
-
conversation = []
|
| 41 |
-
for prompt, answer in history:
|
| 42 |
-
conversation.extend([
|
| 43 |
-
{"role": "user", "content": prompt},
|
| 44 |
-
{"role": "assistant", "content": answer},
|
| 45 |
-
])
|
| 46 |
-
|
| 47 |
-
conversation.append({"role": "user", "content": message})
|
| 48 |
-
|
| 49 |
-
input_text = tokenizer.apply_chat_template(conversation, tokenize=False)
|
| 50 |
-
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
|
| 51 |
-
streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
|
| 52 |
-
|
| 53 |
-
generate_kwargs = dict(
|
| 54 |
-
input_ids=inputs,
|
| 55 |
-
max_new_tokens=max_new_tokens,
|
| 56 |
-
do_sample=temperature != 0,
|
| 57 |
-
top_p=top_p,
|
| 58 |
-
top_k=top_k,
|
| 59 |
-
temperature=temperature,
|
| 60 |
-
streamer=streamer,
|
| 61 |
-
pad_token_id=10,
|
| 62 |
-
)
|
| 63 |
-
|
| 64 |
-
with torch.no_grad():
|
| 65 |
-
thread = Thread(target=model.generate, kwargs=generate_kwargs)
|
| 66 |
-
thread.start()
|
| 67 |
-
|
| 68 |
-
buffer = ""
|
| 69 |
-
for new_text in streamer:
|
| 70 |
-
buffer += new_text
|
| 71 |
-
yield buffer
|
| 72 |
|
| 73 |
#st.set_page_config(layout="wide")
|
| 74 |
# Load custom CSS to integrate Bootstrap, Font Awesome, and Google Fonts
|
|
@@ -150,34 +112,17 @@ with left:
|
|
| 150 |
st.markdown('''<h3><i class="fa fa-pencil"></i> Form 1</h3>''', unsafe_allow_html=True)
|
| 151 |
|
| 152 |
# Box 2: Form 1
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
top_p = st.slider("Top p", 0.0, 1.0, 1.0)
|
| 165 |
-
top_k = st.slider("Top k", 1, 20, 20)
|
| 166 |
-
penalty = st.slider("Repetition penalty", 0.0, 2.0, 1.2)
|
| 167 |
-
|
| 168 |
-
# Handle the chat logic
|
| 169 |
-
if st.button("Send"):
|
| 170 |
-
if user_input:
|
| 171 |
-
response = stream_chat(user_input, st.session_state['history'], temperature, max_new_tokens, top_p, top_k, penalty)
|
| 172 |
-
st.session_state['history'].append((user_input, next(response)))
|
| 173 |
-
for new_text in response:
|
| 174 |
-
st.session_state['history'][-1] = (user_input, new_text)
|
| 175 |
-
st.experimental_rerun()
|
| 176 |
-
|
| 177 |
-
# Display chat history
|
| 178 |
-
for prompt, answer in st.session_state['history']:
|
| 179 |
-
st.write(f"**User:** {prompt}")
|
| 180 |
-
st.write(f"**Mistral-Nemo:** {answer}")
|
| 181 |
|
| 182 |
|
| 183 |
with right:
|
|
|
|
| 6 |
import torch
|
| 7 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
| 8 |
import streamlit as st
|
| 9 |
+
|
| 10 |
+
from transformers import AutoModelForSeq2SeqLM
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
from threading import Thread
|
| 14 |
os.environ["COQUI_TOS_AGREED"] = "1"
|
| 15 |
os.environ["TRAINER_TELEMETRY"]= "0"
|
| 16 |
# Constants
|
| 17 |
+
|
| 18 |
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 19 |
+
|
| 20 |
# Set the device to GPU or CPU
|
| 21 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 22 |
|
| 23 |
+
# Load the tokenizer and model
|
| 24 |
+
tokenizer = AutoTokenizer.from_pretrained("merve/chatgpt-prompts-bart-long")
|
| 25 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("merve/chatgpt-prompts-bart-long", from_tf=True).to("cuda" if torch.cuda.is_available() else "cpu")
|
| 26 |
+
|
| 27 |
+
# Function to generate the prompt based on the persona
|
| 28 |
+
def generate(prompt):
|
| 29 |
+
batch = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 30 |
+
generated_ids = model.generate(batch["input_ids"], max_new_tokens=150)
|
| 31 |
+
output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
| 32 |
+
return output[0]
|
| 33 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
#st.set_page_config(layout="wide")
|
| 36 |
# Load custom CSS to integrate Bootstrap, Font Awesome, and Google Fonts
|
|
|
|
| 112 |
st.markdown('''<h3><i class="fa fa-pencil"></i> Form 1</h3>''', unsafe_allow_html=True)
|
| 113 |
|
| 114 |
# Box 2: Form 1
|
| 115 |
+
persona = st.text_input("Input a persona, e.g. photographer", value="photographer")
|
| 116 |
+
|
| 117 |
+
# Button to trigger generation
|
| 118 |
+
if st.button("Generate Prompt"):
|
| 119 |
+
if persona:
|
| 120 |
+
with st.spinner("Generating..."):
|
| 121 |
+
result = generate(persona)
|
| 122 |
+
st.text_area("Generated Prompt", value=result, height=200)
|
| 123 |
+
else:
|
| 124 |
+
st.error("Please enter a persona to generate a prompt.")
|
| 125 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
|
| 128 |
with right:
|