Update app.py
app.py CHANGED
@@ -5,13 +5,12 @@ import gradio as gr
 import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-from typing import List, Dict, Optional, Tuple
 
 DESCRIPTION = """
-#
+# LlamaEXP
 """
 
-css =
+css ='''
 h1 {
   text-align: center;
   display: block;
@@ -31,76 +30,37 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-model_id = "
+model_id = "prithivMLmods/Llama-Express.1"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
     model_id,
     device_map="auto",
     torch_dtype=torch.bfloat16,
 )
-model.config.sliding_window = 4096
 model.eval()
 
-# Set the pad token ID if it's not already set
-if tokenizer.pad_token_id is None:
-    tokenizer.pad_token_id = tokenizer.eos_token_id
-
-# Define roles for the chat
-class Role:
-    SYSTEM = "system"
-    USER = "user"
-    ASSISTANT = "assistant"
-
-# Default system message
-default_system = "You are a helpful assistant."
-
-def clear_session() -> List:
-    return "", []
-
-def modify_system_session(system: str) -> Tuple[str, str, List]:
-    if system is None or len(system) == 0:
-        system = default_system
-    return system, system, []
-
-def history_to_messages(history: List, system: str) -> List[Dict]:
-    messages = [{'role': Role.SYSTEM, 'content': system}]
-    for h in history:
-        messages.append({'role': Role.USER, 'content': h[0]})
-        messages.append({'role': Role.ASSISTANT, 'content': h[1]})
-    return messages
 
 @spaces.GPU(duration=120)
 def generate(
-
-
-    system: str,
+    message: str,
+    chat_history: list[dict],
     max_new_tokens: int = 1024,
     temperature: float = 0.6,
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-
-    query = ''
-    if history is None:
-        history = []
-
-    # Convert history to messages
-    messages = history_to_messages(history, system)
-    messages.append({'role': Role.USER, 'content': query})
+    conversation = [*chat_history, {"role": "user", "content": message}]
 
-
-
-
-
-
-    )
-    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
+    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
 
-    # Set up the streamer for real-time text generation
     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
-
+        {"input_ids": input_ids},
         streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
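Note: the rewritten body above drops the tuple-based history plumbing (`Role`, `history_to_messages`, the separate `system` argument) in favor of a single OpenAI-style messages list that the tokenizer's chat template formats. A minimal sketch of that flow; it uses a stand-in id tensor instead of the real tokenizer so it runs offline, and the message contents are invented:

```python
import torch

MAX_INPUT_TOKEN_LENGTH = 4096

# History as gr.ChatInterface(type="messages") supplies it: a list of
# {"role": ..., "content": ...} dicts, oldest turn first (invented here).
chat_history = [
    {"role": "user", "content": "Hi!"},
    {"role": "assistant", "content": "Hello! How can I help?"},
]
message = "What is 2 + 2?"

# The new generate() appends the latest user turn...
conversation = [*chat_history, {"role": "user", "content": message}]

# ...then tokenizer.apply_chat_template(conversation,
# add_generation_prompt=True, return_tensors="pt") would produce a
# (1, seq_len) tensor of token ids. Stand-in tensor for illustration:
input_ids = torch.arange(5000).unsqueeze(0)

# Left-truncation keeps the most recent tokens, so the oldest turns
# fall off first and the end of the conversation stays intact.
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
    input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]

print(input_ids.shape)  # torch.Size([1, 4096])
print(input_ids[0, 0])  # tensor(904): the first 904 ids were dropped
```

One subtlety in `generate_kwargs`: `dict` merges a positional mapping with its keyword arguments, so `dict({"input_ids": input_ids}, streamer=streamer, ...)` is equivalent to passing `input_ids=input_ids` directly.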
@@ -109,12 +69,10 @@ def generate(
         temperature=temperature,
         num_beams=1,
         repetition_penalty=repetition_penalty,
-        pad_token_id=tokenizer.pad_token_id,
     )
     t = Thread(target=model.generate, kwargs=generate_kwargs)
     t.start()
 
-    # Stream the output tokens
     outputs = []
     for text in streamer:
         outputs.append(text)
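Note: `model.generate` blocks until the whole completion exists, so it runs on a worker `Thread` while `TextIteratorStreamer` hands decoded fragments to the main generator, which re-yields the accumulated text for Gradio to render incrementally. `TextIteratorStreamer` is essentially a thread-safe queue; a runnable stand-in for the same pattern (all names below are invented for illustration):

```python
from queue import Queue
from threading import Thread

def fake_generate(streamer: Queue) -> None:
    # Plays the role of model.generate(): emits text pieces as they are
    # produced, then a sentinel to mark the end of generation.
    for piece in ["The ", "sun ", "sets", "."]:
        streamer.put(piece)
    streamer.put(None)

def generate_stream():
    streamer: Queue = Queue()
    # Run the producer in the background so consumption can start
    # before generation finishes.
    Thread(target=fake_generate, args=(streamer,)).start()
    outputs = []
    while (piece := streamer.get()) is not None:
        outputs.append(piece)
        yield "".join(outputs)  # Gradio redraws the partial reply each time

for partial in generate_stream():
    print(partial)  # "The ", "The sun ", "The sun sets", "The sun sets."
```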
@@ -124,7 +82,6 @@ def generate(
 demo = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
-        gr.Textbox(label="System Message", value=default_system, lines=2),
         gr.Slider(
             label="Max new tokens",
             minimum=1,
@@ -163,12 +120,14 @@ demo = gr.ChatInterface(
     ],
     stop_btn=None,
     examples=[
-        ["Write a Python function to reverses a string if it's length is a multiple of 4."],
-        ["
-        ["
+        ["Write a Python function to reverses a string if it's length is a multiple of 4. def reverse_string(str1): if len(str1) % 4 == 0: return ''.join(reversed(str1)) return str1 print(reverse_string('abcd')) print(reverse_string('python')) "],
+        ["Rectangle $ABCD$ is the base of pyramid $PABCD$. If $AB = 10$, $BC = 5$, $\overline{PA}\perp \text{plane } ABCD$, and $PA = 8$, then what is the volume of $PABCD$?"],
+        ["Difference between List comprehension and Lambda in Python lst = [x ** 2 for x in range (1, 11) if x % 2 == 1] print(lst)"],
         ["What happens when the sun goes down?"],
     ],
     cache_examples=False,
+    type="messages",
     description=DESCRIPTION,
     css=css,
     fill_height=True,
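Note: `type="messages"` is what makes the deletions in the earlier hunks possible: `ChatInterface` now keeps history as a list of role/content dicts and passes it straight to the callback as `fn(message, history, *additional_inputs)`, so the app no longer needs its own `Role` constants, tuple converters, or the System Message textbox. A sketch of that calling convention with a dummy callback (the history contents are invented):

```python
from collections.abc import Iterator

def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # Dummy stand-in for the app's generate(): echoes instead of sampling.
    yield f"({len(chat_history)} prior messages) you said: {message}"

history = [
    {"role": "user", "content": "Write a haiku."},
    {"role": "assistant", "content": "Old pond / a frog leaps in / splash."},
]

# ChatInterface calls fn(message, history, *additional_inputs); the extra
# positional values are the slider settings, in declaration order.
for partial in generate("Another one, about winter.", history, 1024, 0.6, 0.9, 50, 1.2):
    print(partial)
```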
@@ -176,4 +135,4 @@ demo = gr.ChatInterface(
 
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch(
+    demo.queue(max_size=20).launch()
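Note: the final hunk reduces startup to a bare `demo.queue(max_size=20).launch()`; `queue(max_size=20)` enables the request queue that streamed responses flow through and bounds waiting requests at 20. A minimal end-to-end skeleton of the same wiring, assuming only a recent `gradio` install (the echo callback is invented):

```python
import gradio as gr

def echo(message, history):
    # Streaming stand-in for generate(): yields growing partial replies.
    reply = ""
    for ch in f"You said: {message}":
        reply += ch
        yield reply

demo = gr.ChatInterface(fn=echo, type="messages")

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
```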