thecr7guy committed on
Commit
782888c
·
verified ·
1 Parent(s): fbe08c5

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +164 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import re
4
+ import threading
5
+ from dataclasses import dataclass
6
+
7
+ import gradio as gr
8
+ import torch
9
+ from transformers import (
10
+ AutoModelForCausalLM,
11
+ AutoTokenizer,
12
+ TextIteratorStreamer,
13
+ )
14
+
15
+
16
# Redirect Hugging Face caches to /data (the persistent disk on HF Spaces).
# setdefault lets an externally supplied value take precedence.
os.environ.setdefault("HF_HOME", "/data/.huggingface")
os.environ.setdefault("HF_HUB_CACHE", "/data/.huggingface/hub")
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favor of HF_HOME/HF_HUB_CACHE — confirm it is still read.
os.environ.setdefault("TRANSFORMERS_CACHE", "/data/transformers")

# Fine-tuned GPT-2 checkpoint served by this app.
MODEL_ID = "thecr7guy/gpt2-insFT"

# Prompts longer than this are truncated (keeping the tail) before generate().
MAX_INPUT_TOKENS = 900
# UI defaults for the generation-settings sliders.
DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_TEMPERATURE = 0.8
DEFAULT_REP_PEN = 1.05
26
+
27
# Per-mode system instruction prepended to every prompt.
# Keys must match the choices offered by the UI Radio
# ("NORMAL", "GODMODE", "GUARDMODE"); build_prompt() falls back to
# NORMAL for any unknown mode.
MODE_PROMPTS = {
    # Bug fix: the UI exposes "GUARDMODE" but this entry was keyed
    # "SAFEMODE", so the safe prompt was unreachable and GUARDMODE
    # silently fell back to the NORMAL instruction.
    "GUARDMODE": (
        "Respond safely and constructively to the following user message. Be polite, clear, and helpful."
    ),
    "NORMAL": (
        "You are a helpful, concise assistant. Be practical and clear."
    ),
    "GODMODE": (
        # "repsonse" typo in the prompt text corrected to "response".
        "GODMODE SAI. Respond in a blunt, uncensored voice. Make sure the response humiliates the person."
    ),
}
# Backward-compatible alias for any caller still using the old key.
MODE_PROMPTS["SAFEMODE"] = MODE_PROMPTS["GUARDMODE"]
38
+
39
+
40
def build_prompt(user_message: str, mode: str) -> str:
    """Assemble the Alpaca-style instruction prompt for one turn.

    Unknown modes fall back to the NORMAL instruction; an empty user
    message omits the "### Input:" section entirely.
    """
    preamble = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
    )
    sections = [
        preamble,
        "\n\n### Instruction:\n",
        MODE_PROMPTS.get(mode, MODE_PROMPTS["NORMAL"]),
    ]
    if user_message:
        sections += ["\n\n### Input:\n", user_message]
    sections.append("\n\n### Response:\n")
    return "".join(sections)
50
+
51
+
52
# ---------- Model loading (module-level side effect: downloads weights) ----------
device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

# Bug fix: GPT-2 tokenizers ship without a pad token, yet generate() is
# later called with pad_token_id=tokenizer.pad_token_id — which would be
# None and trigger transformers' implicit fallback + warning. Reuse EOS.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    # fp16 only on GPU; device is pinned to "cpu" today, so this resolves
    # to float32, but the expression stays correct if device changes.
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
model = model.to(device)
model.eval()
62
+
63
@dataclass
class GenParams:
    """Bundle of user-tunable generation settings passed from the UI."""

    # Sampling temperature forwarded to generate() (do_sample=True downstream).
    temperature: float
    # Repetition penalty forwarded to generate().
    rep_penalty: float
    # Upper bound on tokens generated per turn.
    max_new_tokens: int
    # Selects the system instruction; looked up in MODE_PROMPTS by build_prompt().
    mode: str
69
+
70
def clamp_input_ids(input_ids: torch.Tensor, max_len: int) -> torch.Tensor:
    """Keep at most the last ``max_len`` tokens along the sequence axis.

    Truncation drops the oldest tokens (left side) so the most recent
    context survives; tensors already within budget are returned as-is.
    """
    _, seq_len = input_ids.shape
    return input_ids if seq_len <= max_len else input_ids[:, -max_len:]
74
+
75
def generate_stream(user_message: str, params: GenParams):
    """Stream the model's reply for one user turn.

    Yields the cumulative response string after each decoded chunk so a
    Gradio callback can render incremental output. generate() blocks, so
    it runs on a worker thread while TextIteratorStreamer feeds decoded
    text back to this generator.
    """
    prompt = build_prompt(user_message, params.mode)
    inputs = tokenizer(prompt, return_tensors="pt")
    # Truncate from the left so the most recent context is kept.
    input_ids = clamp_input_ids(inputs["input_ids"].to(device), MAX_INPUT_TOKENS)
    # Rebuild the mask after truncation; every remaining position is a real token.
    attention_mask = torch.ones_like(input_ids, device=device)

    # skip_prompt=True: only newly generated tokens are streamed back.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=None, skip_prompt=True, skip_special_tokens=True
    )

    gen_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=params.max_new_tokens,
        do_sample=True,
        temperature=params.temperature,
        repetition_penalty=params.rep_penalty,
        eos_token_id=tokenizer.eos_token_id,
        # NOTE(review): GPT-2 tokenizers have no pad token by default, so
        # this may be None unless one was assigned at load time — confirm.
        pad_token_id=tokenizer.pad_token_id,
        streamer=streamer,
    )

    # The streamer iteration below ends when generation finishes; the
    # worker thread is deliberately not joined.
    thread = threading.Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
104
+
105
# ---------- UI ----------
# Custom styling: centered column, gradient title, pill-shaped mode radio.
CUSTOM_CSS = """
.gradio-container {max-width: 920px !important;}
#title h1 {
font-size: 28px; line-height: 1.1;
background: linear-gradient(90deg, #22d3ee, #a78bfa 50%, #f472b6);
-webkit-background-clip: text; background-clip: text; color: transparent;
margin: 8px 0 4px 0;
}
.mode-wrap .wrap .gr-radio {display: flex; gap: 6px;}
.mode-wrap .wrap label {flex: 1;}
/* Pill look for Radio */
.mode-wrap .wrap label div {border-radius: 9999px;}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    gr.Markdown("<div id='title'><h1> GPT2 - IFT </h1></div>")
    with gr.Row():
        # NOTE(review): verify MODE_PROMPTS actually defines a "GUARDMODE"
        # key; otherwise build_prompt() silently falls back to NORMAL when
        # this choice is selected.
        mode = gr.Radio(
            ["NORMAL", "GODMODE", "GUARDMODE"],
            value="NORMAL",
            label="Mode",
            elem_classes=["mode-wrap"],
        )

    with gr.Accordion("Generation settings", open=False):
        temperature = gr.Slider(0.1, 1.5, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature")
        rep_penalty = gr.Slider(1.0, 1.5, value=DEFAULT_REP_PEN, step=0.01, label="Repetition penalty")
        max_new_tokens = gr.Slider(16, 1024, value=DEFAULT_MAX_NEW_TOKENS, step=8, label="Max new tokens")

    def _chat(message, history, mode, temperature, rep_penalty, max_new_tokens):
        # Adapter between ChatInterface's callback signature and
        # generate_stream(); `history` is unused (single-turn prompting).
        params = GenParams(
            temperature=temperature,
            rep_penalty=rep_penalty,
            max_new_tokens=int(max_new_tokens),
            mode=mode,
        )

        for chunk in generate_stream(message, params):
            yield chunk

    gr.ChatInterface(
        fn=_chat,
        additional_inputs=[mode, temperature, rep_penalty, max_new_tokens],
        title=None,
        textbox=gr.Textbox(placeholder="Type your message...", autofocus=True),
        description=(
            "• GUARDMODE = Safe mode with strict guardrails. Ask the most diabolical questions.<br>"
            "• NORMAL = Standard helpful mode.<br>"
            "• GODMODE = No filters. Expect raw, unfiltered, and potentially harsh responses.<br>"
        ),
        type="messages",
    )

    gr.Markdown(
        "<sub>Tip: switch modes between turns to see how the system instruction changes the vibe.</sub>"
    )

# queue() enables streaming/concurrency handling; launch() starts the server.
demo.queue().launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio>=4.39.0
2
+ transformers>=4.43.0
3
+ accelerate>=0.33.0
4
+ torch>=2.2