Tonic committed on
Commit ba3e817 · 1 Parent(s): e382347

adds harmony in template format

Files changed (4):
  1. app.py +7 -20
  2. app_alternative.py +7 -17
  3. app_harmony.py +203 -0
  4. requirements.txt +2 -1
app.py CHANGED

@@ -12,7 +12,7 @@ try:
         "openai/gpt-oss-20b",
         torch_dtype="auto",
         device_map="auto",
-        attn_implementation="kernel-community/vllm-flash-attention3"
+        attn_implementation="kernels-community/vllm-flash-attention3"
     )
     tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
 
@@ -29,21 +29,6 @@ except Exception as e:
     print(f"❌ Error loading model: {e}")
     raise e
 
-def format_messages(messages):
-    """Format messages into a prompt string"""
-    formatted = ""
-    for message in messages:
-        role = message["role"]
-        content = message["content"]
-        if role == "system":
-            formatted += f"System: {content}\n"
-        elif role == "user":
-            formatted += f"User: {content}\n"
-        elif role == "assistant":
-            formatted += f"Assistant: {content}\n"
-    formatted += "Assistant: "
-    return formatted
-
 def format_conversation_history(chat_history):
     messages = []
     for item in chat_history:
@@ -60,9 +45,11 @@ def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
     system_message = [{"role": "system", "content": system_prompt}] if system_prompt else []
     processed_history = format_conversation_history(chat_history)
     messages = system_message + processed_history + [new_message]
-
-    # Format the prompt
-    prompt = format_messages(messages)
+    prompt = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True
+    )
 
     # Create streamer for proper streaming
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
@@ -80,7 +67,7 @@ def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
         "use_cache": True
     }
 
-    # Tokenize input
+    # Tokenize input using the chat template
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
     # Start generation in a separate thread
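
The change above replaces the hand-rolled "System:"/"User:" transcript from format_messages with the chat template that ships with the model's tokenizer, so the prompt is rendered with the exact special tokens gpt-oss-20b was trained on. A minimal sketch of the call pattern (the sample messages are illustrative, not from the app):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

# tokenize=False returns the rendered prompt as a string;
# add_generation_prompt=True appends the opening of the assistant turn
# so generation continues from there.
prompt = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)
print(prompt)  # Harmony-formatted prompt string with the model's special tokens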
app_alternative.py CHANGED

@@ -29,21 +29,6 @@ except Exception as e:
     print(f"❌ Error loading model: {e}")
     raise e
 
-def format_messages(messages):
-    """Format messages into a prompt string"""
-    formatted = ""
-    for message in messages:
-        role = message["role"]
-        content = message["content"]
-        if role == "system":
-            formatted += f"System: {content}\n"
-        elif role == "user":
-            formatted += f"User: {content}\n"
-        elif role == "assistant":
-            formatted += f"Assistant: {content}\n"
-    formatted += "Assistant: "
-    return formatted
-
 def format_conversation_history(chat_history):
     messages = []
     for item in chat_history:
@@ -61,8 +46,13 @@ def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
     processed_history = format_conversation_history(chat_history)
     messages = system_message + processed_history + [new_message]
 
-    # Format the prompt
-    prompt = format_messages(messages)
+    # Use the model's chat template to format the conversation properly
+    # This is crucial for GPT-OSS-20B, which expects the Harmony format
+    prompt = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True
+    )
 
     # Alternative streaming approach with manual chunking
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
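
The "Harmony format" named in the new comments is the role-and-channel wire format the gpt-oss models are trained on. Rendered through the chat template, a short conversation comes out roughly in this shape (shape only; the real template also emits channel and developer metadata, so the authoritative rendering is whatever apply_chat_template returns):

<|start|>system<|message|>You are a helpful assistant.<|end|>
<|start|>user<|message|>Hello!<|end|>
<|start|>assistant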
app_harmony.py ADDED

@@ -0,0 +1,203 @@
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+import torch
+from threading import Thread
+import gradio as gr
+import spaces
+import re
+from peft import PeftModel
+
+# Load the base model
+try:
+    base_model = AutoModelForCausalLM.from_pretrained(
+        "openai/gpt-oss-20b",
+        torch_dtype="auto",
+        device_map="auto",
+        attn_implementation="kernels-community/vllm-flash-attention3"
+    )
+    tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
+
+    # Load the LoRA adapter
+    try:
+        model = PeftModel.from_pretrained(base_model, "Tonic/gpt-oss-20b-multilingual-reasoner")
+        print("✅ LoRA model loaded successfully!")
+    except Exception as lora_error:
+        print(f"⚠️ LoRA adapter failed to load: {lora_error}")
+        print("🔄 Falling back to base model...")
+        model = base_model
+
+except Exception as e:
+    print(f"❌ Error loading model: {e}")
+    raise e
+
+def format_conversation_history(chat_history):
+    messages = []
+    for item in chat_history:
+        role = item["role"]
+        content = item["content"]
+        if isinstance(content, list):
+            content = content[0]["text"] if content and "text" in content[0] else str(content)
+        messages.append({"role": role, "content": content})
+    return messages
+
+def create_harmony_prompt(messages, reasoning_level="medium"):
+    """
+    Create a proper Harmony format prompt for GPT-OSS-20B.
+    Based on the Harmony format from https://github.com/openai/harmony
+    """
+    # Start with the system message in Harmony format
+    system_content = f"""You are ChatGPT, a large language model trained by OpenAI.
+Knowledge cutoff: 2024-06
+Current date: 2025-01-28
+
+Reasoning: {reasoning_level}
+
+# Valid channels: analysis, commentary, final. Channel must be included for every message."""
+
+    # Build the prompt in Harmony format
+    prompt_parts = []
+
+    # Add the system message
+    prompt_parts.append(f"<|start|>system<|message|>{system_content}<|end|>")
+
+    # Add the conversation messages
+    for message in messages:
+        role = message["role"]
+        content = message["content"]
+
+        if role == "system":
+            # Skip system messages; the main one was already added
+            continue
+        elif role == "user":
+            prompt_parts.append(f"<|start|>user<|message|>{content}<|end|>")
+        elif role == "assistant":
+            prompt_parts.append(f"<|start|>assistant<|message|>{content}<|end|>")
+
+    # Add the generation prompt
+    prompt_parts.append("<|start|>assistant")
+
+    return "\n".join(prompt_parts)
+
+@spaces.GPU(duration=60)
+def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
+    new_message = {"role": "user", "content": input_data}
+    system_message = [{"role": "system", "content": system_prompt}] if system_prompt else []
+    processed_history = format_conversation_history(chat_history)
+    messages = system_message + processed_history + [new_message]
+
+    # Extract the reasoning level from the system prompt
+    reasoning_level = "medium"
+    if "reasoning:" in system_prompt.lower():
+        if "high" in system_prompt.lower():
+            reasoning_level = "high"
+        elif "low" in system_prompt.lower():
+            reasoning_level = "low"
+
+    # Create the Harmony format prompt
+    prompt = create_harmony_prompt(messages, reasoning_level)
+
+    # Create streamer for proper streaming
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    # Prepare generation kwargs
+    generation_kwargs = {
+        "max_new_tokens": max_new_tokens,
+        "do_sample": True,
+        "temperature": temperature,
+        "top_p": top_p,
+        "top_k": top_k,
+        "repetition_penalty": repetition_penalty,
+        "pad_token_id": tokenizer.eos_token_id,
+        "streamer": streamer,
+        "use_cache": True
+    }
+
+    # Tokenize the Harmony format prompt
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+    # Start generation in a separate thread
+    thread = Thread(target=model.generate, kwargs={**inputs, **generation_kwargs})
+    thread.start()
+
+    # Stream the response and parse the Harmony channels
+    current_content = ""
+    thinking = ""
+    final = ""
+
+    for chunk in streamer:
+        current_content += chunk
+
+        # Re-parse the accumulated buffer on every chunk, rebuilding the
+        # channel contents from scratch so repeated passes do not duplicate text.
+        # Channel markers look like <|channel|>analysis, <|channel|>commentary, <|channel|>final
+        thinking = ""
+        final = ""
+        for channel_part in current_content.split("<|channel|>")[1:]:
+            content_start = channel_part.find("<|message|>")
+            if content_start == -1:
+                continue
+            content = channel_part[content_start + len("<|message|>"):].split("<|end|>")[0]
+            if channel_part.startswith("analysis") or channel_part.startswith("commentary"):
+                thinking += content
+            elif channel_part.startswith("final"):
+                final += content
+
+        # Clean up the content for display
+        clean_thinking = re.sub(r'^analysis\s*', '', thinking).strip()
+        clean_final = final.strip()
+
+        # Format for display
+        if clean_thinking or clean_final:
+            formatted = f"<details open><summary>Click to view Thinking Process</summary>\n\n{clean_thinking}\n\n</details>\n\n{clean_final}"
+            yield formatted
+
+demo = gr.ChatInterface(
+    fn=generate_response,
+    additional_inputs=[
+        gr.Slider(label="Max new tokens", minimum=64, maximum=4096, step=1, value=2048),
+        gr.Textbox(
+            label="System Prompt",
+            value="You are a helpful assistant. Reasoning: medium",
+            lines=4,
+            placeholder="Change system prompt"
+        ),
+        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
+        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
+        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
+        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0)
+    ],
+    examples=[
+        [{"text": "Explain Newton's laws clearly and concisely"}],
+        [{"text": "Write a Python function to calculate the Fibonacci sequence"}],
+        [{"text": "What are the benefits of open-weight AI models?"}],
+    ],
+    cache_examples=False,
+    type="messages",
+    description="""
+    # 🙋🏻‍♂️Welcome to 🌟Tonic's gpt-oss-20b Multilingual Reasoner Demo!
+    Wait a couple of seconds initially. You can adjust the reasoning level in the system prompt, e.g. "Reasoning: high".
+    This version uses the proper Harmony format for better generation quality.
+    """,
+    fill_height=True,
+    textbox=gr.Textbox(
+        label="Query Input",
+        placeholder="Type your prompt"
+    ),
+    stop_btn="Stop Generation",
+    multimodal=False,
+    theme=gr.themes.Soft()
+)
+
+if __name__ == "__main__":
+    demo.launch(share=True)
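
Two caveats on the hand-rolled channel parser above: the streamer is built with skip_special_tokens=True, so if the tokenizer registers the Harmony markers as special tokens they may be stripped before the string matching ever sees them, and splitting decoded text is brittle in general. The openai-harmony package added to requirements.txt below can render and parse these conversations at the token level instead. A sketch, with API names taken from the openai/harmony README; treat them as assumptions to verify against the installed version:

from openai_harmony import (
    Conversation,
    HarmonyEncodingName,
    Message,
    Role,
    load_harmony_encoding,
)

# Token-level Harmony encoding for the gpt-oss family
enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)

convo = Conversation.from_messages([
    Message.from_role_and_content(Role.USER, "Hello!"),
])

# Token IDs for the prefill, ready to pass to model.generate(...)
prefill_ids = enc.render_conversation_for_completion(convo, Role.ASSISTANT)

# After generation, parse the completion tokens back into channelled messages
# instead of regex-matching decoded text (completion_ids is assumed to hold
# only the newly generated token IDs):
# entries = enc.parse_messages_from_completion_tokens(completion_ids, Role.ASSISTANT)
# for m in entries:
#     print(m.channel, m.content)  # e.g. "analysis" vs "final"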
requirements.txt CHANGED

@@ -5,4 +5,5 @@ trl
 bitsandbytes
 triton
 accelerate
-kernels
+kernels
+openai-harmony