d3dname committed on
Commit
800cb10
·
verified ·
1 Parent(s): 7635e03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -83
app.py CHANGED
@@ -6,69 +6,31 @@ import time
6
  import torch
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
8
  import streamlit as st
 
 
 
 
9
  from threading import Thread
10
  os.environ["COQUI_TOS_AGREED"] = "1"
11
  os.environ["TRAINER_TELEMETRY"]= "0"
12
  # Constants
13
- MODEL_LIST = ["mistralai/Mistral-Nemo-Instruct-2407"]
14
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
15
- MODEL = os.environ.get("MODEL_ID", "mistralai/Mistral-Nemo-Instruct-2407")
16
  # Set the device to GPU or CPU
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
 
19
- # Load the model and tokenizer
20
- tokenizer = AutoTokenizer.from_pretrained(MODEL)
21
- model = AutoModelForCausalLM.from_pretrained(
22
- MODEL,
23
- torch_dtype=torch.bfloat16,
24
- device_map="auto",
25
- ignore_mismatched_sizes=True
26
- ).to(device)
27
-
28
- # CSS styles
29
- CSS = """
30
- <style>
31
- h1 {
32
- text-align: center;
33
- }
34
- </style>
35
- """
36
-
37
-
38
- # Function to handle chat
39
- def stream_chat(message, history, temperature=0.3, max_new_tokens=1024, top_p=1.0, top_k=20, penalty=1.2):
40
- conversation = []
41
- for prompt, answer in history:
42
- conversation.extend([
43
- {"role": "user", "content": prompt},
44
- {"role": "assistant", "content": answer},
45
- ])
46
-
47
- conversation.append({"role": "user", "content": message})
48
-
49
- input_text = tokenizer.apply_chat_template(conversation, tokenize=False)
50
- inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
51
- streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
52
-
53
- generate_kwargs = dict(
54
- input_ids=inputs,
55
- max_new_tokens=max_new_tokens,
56
- do_sample=temperature != 0,
57
- top_p=top_p,
58
- top_k=top_k,
59
- temperature=temperature,
60
- streamer=streamer,
61
- pad_token_id=10,
62
- )
63
-
64
- with torch.no_grad():
65
- thread = Thread(target=model.generate, kwargs=generate_kwargs)
66
- thread.start()
67
-
68
- buffer = ""
69
- for new_text in streamer:
70
- buffer += new_text
71
- yield buffer
72
 
73
  #st.set_page_config(layout="wide")
74
  # Load custom CSS to integrate Bootstrap, Font Awesome, and Google Fonts
@@ -150,34 +112,17 @@ with left:
150
  st.markdown('''<h3><i class="fa fa-pencil"></i> Form 1</h3>''', unsafe_allow_html=True)
151
 
152
  # Box 2: Form 1
153
- # Chat history
154
- if 'history' not in st.session_state:
155
- st.session_state['history'] = []
156
-
157
- # Chat input
158
- user_input = st.text_input("Your Message:", "")
159
-
160
- # Chat parameters
161
- with st.expander("⚙️ Parameters"):
162
- temperature = st.slider("Temperature", 0.0, 1.0, 0.3)
163
- max_new_tokens = st.slider("Max new tokens", 128, 8192, 1024)
164
- top_p = st.slider("Top p", 0.0, 1.0, 1.0)
165
- top_k = st.slider("Top k", 1, 20, 20)
166
- penalty = st.slider("Repetition penalty", 0.0, 2.0, 1.2)
167
-
168
- # Handle the chat logic
169
- if st.button("Send"):
170
- if user_input:
171
- response = stream_chat(user_input, st.session_state['history'], temperature, max_new_tokens, top_p, top_k, penalty)
172
- st.session_state['history'].append((user_input, next(response)))
173
- for new_text in response:
174
- st.session_state['history'][-1] = (user_input, new_text)
175
- st.experimental_rerun()
176
-
177
- # Display chat history
178
- for prompt, answer in st.session_state['history']:
179
- st.write(f"**User:** {prompt}")
180
- st.write(f"**Mistral-Nemo:** {answer}")
181
 
182
 
183
  with right:
 
6
  import torch
7
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
8
  import streamlit as st
9
+
10
+ from transformers import AutoModelForSeq2SeqLM
11
+ import torch
12
+
13
  from threading import Thread
14
  os.environ["COQUI_TOS_AGREED"] = "1"
15
  os.environ["TRAINER_TELEMETRY"]= "0"
16
  # Constants
17
+
18
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
19
+
20
  # Set the device to GPU or CPU
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
 
23
# Load the tokenizer and model once at import time.
# NOTE(review): Streamlit re-executes the script on widget interaction, so
# this load may repeat on every rerun -- consider caching it; TODO confirm.
PROMPT_MODEL_ID = "merve/chatgpt-prompts-bart-long"

tokenizer = AutoTokenizer.from_pretrained(PROMPT_MODEL_ID)
# from_tf=True is kept from the original call (presumably the checkpoint ships
# TensorFlow weights -- verify). The model is placed on the `device` chosen
# above instead of repeating the CUDA availability check inline.
model = AutoModelForSeq2SeqLM.from_pretrained(PROMPT_MODEL_ID, from_tf=True).to(device)
26
+
27
# Function to generate the prompt based on the persona
def generate(prompt: str) -> str:
    """Generate text for *prompt* with the module-level seq2seq model.

    Args:
        prompt: Input text (the UI passes the persona string).

    Returns:
        The first decoded sequence produced by ``model.generate``, with
        special tokens removed.
    """
    # Keep the encoded inputs on the same device as the model weights.
    batch = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Forward the whole encoding so the attention mask reaches generate();
    # the original passed only input_ids and silently dropped the mask.
    generated_ids = model.generate(**batch, max_new_tokens=150)
    output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return output[0]
33
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  #st.set_page_config(layout="wide")
36
  # Load custom CSS to integrate Bootstrap, Font Awesome, and Google Fonts
 
112
  st.markdown('''<h3><i class="fa fa-pencil"></i> Form 1</h3>''', unsafe_allow_html=True)
113
 
114
  # Box 2: Form 1
115
+ persona = st.text_input("Input a persona, e.g. photographer", value="photographer")
116
+
117
+ # Button to trigger generation
118
+ if st.button("Generate Prompt"):
119
+ if persona:
120
+ with st.spinner("Generating..."):
121
+ result = generate(persona)
122
+ st.text_area("Generated Prompt", value=result, height=200)
123
+ else:
124
+ st.error("Please enter a persona to generate a prompt.")
125
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
 
128
  with right: