smartdigitalnetworks commited on
Commit
d540e64
·
verified ·
1 Parent(s): 7f59467

Upload generativetext.py

Browse files
Files changed (1) hide show
  1. generativetext.py +123 -0
generativetext.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import os
4
+ from collections.abc import Iterator
5
+ from threading import Thread
6
+
7
+ import gradio as gr
8
+ import spaces
9
+ import torch
10
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
11
+
12
+ DESCRIPTION = "# 💬 Generative Text"
13
+
14
+ if not torch.cuda.is_available():
15
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
16
+
17
+ MAX_MAX_NEW_TOKENS = 2048
18
+ DEFAULT_MAX_NEW_TOKENS = 1024
19
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
20
+
21
+ if torch.cuda.is_available():
22
+ model_id = "mistralai/Mistral-7B-Instruct-v0.3"
23
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
24
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
25
+
26
+
27
+ @spaces.GPU
28
+ def generate(
29
+ message: str,
30
+ chat_history: list[dict],
31
+ max_new_tokens: int = 1024,
32
+ temperature: float = 0.6,
33
+ top_p: float = 0.9,
34
+ top_k: int = 50,
35
+ repetition_penalty: float = 1.2,
36
+ ) -> Iterator[str]:
37
+ conversation = [*chat_history, {"role": "user", "content": message}]
38
+
39
+ input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
40
+ if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
41
+ input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
42
+ gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
43
+ input_ids = input_ids.to(model.device)
44
+
45
+ streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
46
+ generate_kwargs = dict(
47
+ {"input_ids": input_ids},
48
+ streamer=streamer,
49
+ max_new_tokens=max_new_tokens,
50
+ do_sample=True,
51
+ top_p=top_p,
52
+ top_k=top_k,
53
+ temperature=temperature,
54
+ num_beams=1,
55
+ repetition_penalty=repetition_penalty,
56
+ )
57
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
58
+ t.start()
59
+
60
+ outputs = []
61
+ for text in streamer:
62
+ outputs.append(text)
63
+ yield "".join(outputs)
64
+
65
+
66
+ demo = gr.ChatInterface(
67
+ fn=generate,
68
+ additional_inputs=[
69
+ gr.Slider(
70
+ label="Max new tokens",
71
+ minimum=1,
72
+ maximum=MAX_MAX_NEW_TOKENS,
73
+ step=1,
74
+ value=DEFAULT_MAX_NEW_TOKENS,
75
+ ),
76
+ gr.Slider(
77
+ label="Temperature",
78
+ minimum=0.1,
79
+ maximum=4.0,
80
+ step=0.1,
81
+ value=0.6,
82
+ ),
83
+ gr.Slider(
84
+ label="Top-p (nucleus sampling)",
85
+ minimum=0.05,
86
+ maximum=1.0,
87
+ step=0.05,
88
+ value=0.9,
89
+ ),
90
+ gr.Slider(
91
+ label="Top-k",
92
+ minimum=1,
93
+ maximum=1000,
94
+ step=1,
95
+ value=50,
96
+ ),
97
+ gr.Slider(
98
+ label="Repetition penalty",
99
+ minimum=1.0,
100
+ maximum=2.0,
101
+ step=0.05,
102
+ value=1.2,
103
+ ),
104
+ ],
105
+ stop_btn=None,
106
+ examples=[
107
+ ["Hello is there anybody in there? Just nod if you can hear me? Is there anybody home?"],
108
+ ["Can you explain briefly to me who is Banksy?"],
109
+ ["Explain the plot of Jerry Hesketh's HTC VIVE'Virtual Reality Experiece Sci-Fi Galaxy and SciFiGalaxy.com in 5 sentences."],
110
+ ["Is the Patrick Winston MIT 6.034 Artificial Intelligence course any good? Where might I find more info?"],
111
+ ["Write a 100-word article on Written by Howard Suber founding chair of the Film & Television Producers Program at UCLA and his Book and Video Series THE POWER OF FILM"],
112
+ ["What is ZionDub.com? Tell me three key features about what is going on there and Contact information."],
113
+ ["Where can I download the Augmented Reality application Ultratime. Who made it and where has it been used?"],
114
+ ["Who holds the Trademark and/or Copyright for Smart Digital Telelvision and who is the author, when was it founded in 12 sentences."],
115
+ ["Provide information on how to investing in Smart Digital Networks, Inc. Can you provide information and links to thier products."],
116
+ ],
117
+ type="messages",
118
+ description=DESCRIPTION,
119
+ css_paths="style.css",
120
+ )
121
+
122
+ if __name__ == "__main__":
123
+ demo.queue(max_size=20).launch()