# spotless-chat / app.py
# (HuggingFace Space page header captured by the scrape — not Python code.
#  Commit note: "Update to 50k step model", revision 2630728, verified.)
"""
Spotless Bin Co Customer Service - HuggingFace Space
A specialized customer service chatbot for trash bin cleaning.
"""
import os
import sys
import torch
import gradio as gr
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
from dataclasses import dataclass
# =============================================================================
# Model Architecture (exact copy from smelter)
# =============================================================================
@dataclass
class ModelConfig:
    """Architecture hyperparameters for the Distillix 100M student model."""
    vocab_size: int = 32000          # matches the Llama-2 tokenizer loaded below
    hidden_dim: int = 768
    num_layers: int = 12
    num_heads: int = 12
    num_kv_heads: int = 4            # grouped-query attention: 3 query heads per kv head
    head_dim: int = 64
    mlp_hidden_dim: int = 2048
    max_seq_len: int = 2048
    dropout: float = 0.0             # NOTE(review): not referenced by any layer in this file
    rope_base: float = 1000000.0     # RoPE frequency base
    use_bitlinear: bool = True       # ternary BitLinear for attention projections
    use_bitlinear_ffn: bool = False  # MLP uses regular Linear
    norm_eps: float = 1e-6
    tie_word_embeddings: bool = True
    use_qk_norm: bool = True         # per-head RMSNorm on q/k before RoPE
    attn_logit_soft_cap: float = 50.0   # tanh soft cap on attention logits
    final_logit_soft_cap: float = 30.0  # tanh soft cap on output logits

    @property
    def num_kv_groups(self) -> int:
        # Number of query heads that share each key/value head.
        return self.num_heads // self.num_kv_heads
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalization with a learned per-channel gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Scale each vector by the reciprocal RMS of its last dimension,
        # then apply the learned gain.
        inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * (x * inv_rms)
class BitLinear(torch.nn.Module):
    """BitNet b1.58 linear layer with ternary weights {-1, 0, +1}."""

    def __init__(self, in_features: int, out_features: int, bias: bool = False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.nn.Parameter(torch.empty(out_features, in_features))
        self.eps = 1e-8
        torch.nn.init.kaiming_uniform_(self.weight)

    def forward(self, x):
        # Per-token absmax scaling of activations.
        act_scale = x.abs().max(dim=-1, keepdim=True).values.clamp(min=self.eps)
        x_scaled = x / act_scale
        # Ternarize weights: +/-1 where |w| exceeds half the mean magnitude, else 0.
        mean_mag = self.weight.abs().mean()
        ternary = torch.sign(self.weight) * (self.weight.abs() > 0.5 * mean_mag).float()
        # Matmul in the quantized domain, then undo both scale factors.
        y = torch.nn.functional.linear(x_scaled, ternary)
        return y * mean_mag * act_scale
class RotaryEmbedding(torch.nn.Module):
    """Precomputes rotary position embedding (RoPE) cos/sin tables."""

    def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 1000000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        # Inverse frequencies for the even channel indices.
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)

    def forward(self, x, seq_len: int):
        # Angle per (position, frequency) pair, duplicated to span the full dim.
        positions = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
        angles = torch.outer(positions, self.inv_freq)
        angles = torch.cat((angles, angles), dim=-1)
        return angles.cos(), angles.sin()
def rotate_half(x):
    """Rotate the last dimension by half: (a, b) -> (-b, a)."""
    # Ceil split matches torch.chunk's behavior for odd sizes.
    split = (x.shape[-1] + 1) // 2
    front, back = x[..., :split], x[..., split:]
    return torch.cat((-back, front), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to query and key tensors.

    cos/sin are (T, D) tables broadcast over the (batch, head) leading dims
    of q and k, which are shaped (B, H, T, D).
    """
    cos = cos[None, None, :, :]
    sin = sin[None, None, :, :]

    def _half_rotate(t):
        # Inline half-rotation: (a, b) -> (-b, a) on the last dimension.
        a, b = t.chunk(2, dim=-1)
        return torch.cat((-b, a), dim=-1)

    return q * cos + _half_rotate(q) * sin, k * cos + _half_rotate(k) * sin
class Attention(torch.nn.Module):
    """Grouped-query causal self-attention with RoPE, optional per-head
    QK RMSNorm, and tanh soft-capping of attention logits."""
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        self.num_kv_groups = config.num_kv_groups
        # Ternary BitLinear projections when enabled; plain Linear otherwise.
        LinearClass = BitLinear if config.use_bitlinear else torch.nn.Linear
        self.q_proj = LinearClass(config.hidden_dim, config.num_heads * config.head_dim, bias=False)
        self.k_proj = LinearClass(config.hidden_dim, config.num_kv_heads * config.head_dim, bias=False)
        self.v_proj = LinearClass(config.hidden_dim, config.num_kv_heads * config.head_dim, bias=False)
        self.o_proj = LinearClass(config.num_heads * config.head_dim, config.hidden_dim, bias=False)
        if config.use_qk_norm:
            # Per-head RMSNorm applied to queries/keys before RoPE.
            self.q_norm = RMSNorm(config.head_dim, eps=config.norm_eps)
            self.k_norm = RMSNorm(config.head_dim, eps=config.norm_eps)
        self.rotary = RotaryEmbedding(config.head_dim, config.max_seq_len, config.rope_base)
    def forward(self, x, mask=None):
        # NOTE(review): `mask` is accepted but never used — a fresh causal
        # mask is built below on every call. Confirm callers rely on that.
        B, T, C = x.shape
        # Project and reshape to (B, heads, T, head_dim).
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        if self.config.use_qk_norm:
            q = self.q_norm(q)
            k = self.k_norm(k)
        # Rotate q/k by position-dependent angles (RoPE).
        cos, sin = self.rotary(x, T)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)
        if self.num_kv_groups > 1:
            # Grouped-query attention: replicate each kv head across its group
            # so shapes line up with the query heads.
            k = k.repeat_interleave(self.num_kv_groups, dim=1)
            v = v.repeat_interleave(self.num_kv_groups, dim=1)
        scale = self.head_dim ** -0.5
        attn = torch.matmul(q, k.transpose(-2, -1)) * scale
        if self.config.attn_logit_soft_cap:
            # Squash logits into (-cap, cap) with tanh before masking.
            cap = self.config.attn_logit_soft_cap
            attn = cap * torch.tanh(attn / cap)
        # Causal mask: each position may only attend to itself and the past.
        causal_mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)
        attn = attn.masked_fill(causal_mask, float('-inf'))
        attn = torch.softmax(attn, dim=-1)
        out = torch.matmul(attn, v)
        # Merge heads back to (B, T, heads*head_dim) and project to hidden_dim.
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.o_proj(out)
class MLP(torch.nn.Module):
    """SwiGLU feed-forward block: down(silu(gate(x)) * up(x))."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.gate_proj = torch.nn.Linear(config.hidden_dim, config.mlp_hidden_dim, bias=False)
        self.up_proj = torch.nn.Linear(config.hidden_dim, config.mlp_hidden_dim, bias=False)
        self.down_proj = torch.nn.Linear(config.mlp_hidden_dim, config.hidden_dim, bias=False)

    def forward(self, x):
        # Gate path goes through SiLU, then modulates the linear "up" path.
        gated = torch.nn.functional.silu(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class TransformerBlock(torch.nn.Module):
    """Pre-norm transformer layer: attention then MLP, each with a residual."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.input_layernorm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
        self.self_attn = Attention(config)
        self.post_attention_layernorm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
        self.mlp = MLP(config)

    def forward(self, x):
        # Attention sub-layer with residual connection.
        attn_out = self.self_attn(self.input_layernorm(x))
        x = x + attn_out
        # Feed-forward sub-layer with residual connection.
        mlp_out = self.mlp(self.post_attention_layernorm(x))
        return x + mlp_out
class StudentLLM(torch.nn.Module):
    """Decoder-only LM: token embedding, transformer stack, final RMSNorm,
    and a (optionally weight-tied) LM head with tanh logit soft-capping."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_dim)
        self.layers = torch.nn.ModuleList(
            TransformerBlock(config) for _ in range(config.num_layers)
        )
        self.norm = RMSNorm(config.hidden_dim, eps=config.norm_eps)
        self.lm_head = torch.nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            # Share one matrix between the input embedding and the output head.
            self.lm_head.weight = self.embed_tokens.weight

    def forward(self, input_ids):
        hidden = self.embed_tokens(input_ids)
        for block in self.layers:
            hidden = block(hidden)
        logits = self.lm_head(self.norm(hidden))
        cap = self.config.final_logit_soft_cap
        if cap:
            # Squash final logits into (-cap, cap) for stable sampling.
            logits = cap * torch.tanh(logits / cap)
        return logits
# =============================================================================
# Global Model Loading
# =============================================================================
# Lazily-initialized singletons, populated once by load_model().
model = None
tokenizer = None
def load_model():
    """Download and initialize the model and tokenizer (idempotent).

    Populates the module-level `model` and `tokenizer` globals. The globals
    are assigned only after every step succeeds, so a failure mid-load (e.g.
    a network error during download) leaves them None and the next call
    retries cleanly instead of reusing a half-initialized model.
    """
    global model, tokenizer
    if model is not None:
        return
    print("Loading Spotless Customer Service Model...")
    config = ModelConfig()
    net = StudentLLM(config)
    print("Downloading weights...")
    weights_path = hf_hub_download(
        repo_id="rileyseaburg/distillix-spotless",
        filename="spotless-50k-final.pt"
    )
    print("Loading weights...")
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects from the
    # checkpoint; acceptable only because this is a trusted first-party repo.
    ckpt = torch.load(weights_path, map_location='cpu', weights_only=False)
    state_dict = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt
    # strict=False tolerates renamed/extra entries, but surface what was
    # skipped so silent weight mismatches don't masquerade as a bad model.
    result = net.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print(f"Warning: missing keys in checkpoint: {result.missing_keys}")
    if result.unexpected_keys:
        print(f"Warning: unexpected keys in checkpoint: {result.unexpected_keys}")
    net.eval()
    print("Loading tokenizer...")
    tok = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
    # Publish the globals last, once everything above has succeeded.
    model, tokenizer = net, tok
    print("Model ready!")
@torch.inference_mode()
def generate(message, history, max_tokens=100, temperature=0.7):
    """Generate one agent reply for the customer-service chat.

    Args:
        message: The customer's latest message.
        history: Prior turns from the Gradio chatbot, as (customer, agent)
            tuple pairs. Previously accepted but ignored; now included in the
            prompt so the model sees conversation context.
        max_tokens: Upper bound on newly sampled tokens.
        temperature: Sampling temperature. Values <= 0 now mean greedy
            (argmax) decoding; previously the scale was skipped but sampling
            still happened at an effective temperature of 1.

    Returns:
        The agent's response text, or an apology string on any failure.
    """
    try:
        load_model()
        # Build the training-format transcript: prior turns, then the new one.
        turns = []
        for past in history or []:
            # Assumes Gradio tuple-format history: (customer_msg, agent_msg).
            user_msg, agent_msg = past[0], past[1]
            turns.append(f"Customer: {user_msg}")
            if agent_msg:
                turns.append(f"Agent: {agent_msg}")
        turns.append(f"Customer: {message}\nAgent:")
        prompt = "\n".join(turns)
        inputs = tokenizer(prompt, return_tensors="pt")
        generated = inputs["input_ids"].clone()
        for _ in range(int(max_tokens)):
            if generated.shape[1] >= 512:
                break  # hard context-length cap
            logits = model(generated)
            next_token_logits = logits[:, -1, :]
            if temperature > 0:
                next_token_logits = next_token_logits / temperature
                # Top-p (nucleus) sampling with p=0.9.
                sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > 0.9
                # Shift right so the first token crossing the threshold is kept.
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                next_token_logits[indices_to_remove] = float('-inf')
                probs = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # Non-positive temperature: deterministic greedy decoding.
                next_token = next_token_logits.argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=-1)
            # Stop conditions.
            if next_token.item() == tokenizer.eos_token_id:
                break
            # Stop once the agent's turn looks complete (newline after content).
            decoded_so_far = tokenizer.decode(generated[0], skip_special_tokens=True)
            if "Agent:" in decoded_so_far:
                agent_part = decoded_so_far.split("Agent:")[-1]
                if "\n" in agent_part and len(agent_part.strip()) > 10:
                    break
        output = tokenizer.decode(generated[0], skip_special_tokens=True)
        # Extract just the newest agent turn from the decoded transcript.
        if "Agent:" in output:
            response = output.split("Agent:")[-1].strip()
            # Trim anything the model hallucinated for the next customer turn.
            if "Customer:" in response:
                response = response.split("Customer:")[0].strip()
            # Keep only the first line of the reply.
            response = response.split("\n")[0].strip()
        else:
            response = output
        return response
    except Exception as e:
        return f"I apologize, I'm having technical difficulties. Error: {str(e)}"
# =============================================================================
# Gradio Interface
# =============================================================================
CSS = """
.gradio-container {
max-width: 800px !important;
}
"""
# Gradio UI: intro markdown, chat widget, input row, sampling settings,
# example prompts, and the submit/click wiring into generate().
with gr.Blocks(title="Spotless Bin Co - Customer Service", css=CSS) as demo:
    gr.Markdown("""
# Spotless Bin Co - Customer Service
Welcome! I'm your AI assistant for **Spotless Bin Co**, your local trash bin cleaning service.
I can help you with:
- Scheduling bin cleaning appointments
- Pricing and service information
- Billing questions
- Cancellations and account changes
- General inquiries
---
*Powered by Distillix 100M BitNet - a 100M parameter AI running locally*
""")
    chatbot = gr.Chatbot(
        label="Chat",
        height=400,
        show_label=False,
    )
    with gr.Row():
        msg = gr.Textbox(
            label="Your Message",
            placeholder="Type your question here...",
            scale=4,
            show_label=False,
        )
        send_btn = gr.Button("Send", variant="primary", scale=1)
    # Advanced sampling controls, collapsed by default.
    with gr.Accordion("Settings", open=False):
        max_tokens = gr.Slider(32, 200, value=100, step=8, label="Max Response Length")
        temperature = gr.Slider(0.1, 1.2, value=0.7, step=0.1, label="Temperature")
    gr.Examples(
        examples=[
            "Hi, I need to schedule a bin cleaning",
            "How much does your service cost?",
            "I want to cancel my subscription",
            "My bin wasn't cleaned yesterday",
            "What areas do you service?",
            "Do you clean recycling bins too?",
        ],
        inputs=msg,
        label="Quick Questions"
    )

    def respond(message, chat_history, max_tok, temp):
        # Event handler shared by Enter-submit and the Send button.
        # Returns (new textbox value, updated history); empty input is a no-op.
        if not message.strip():
            return "", chat_history
        response = generate(message, chat_history, max_tok, temp)
        # Tuple-format history entry expected by this Chatbot configuration.
        chat_history.append((message, response))
        return "", chat_history

    msg.submit(respond, [msg, chatbot, max_tokens, temperature], [msg, chatbot])
    send_btn.click(respond, [msg, chatbot, max_tokens, temperature], [msg, chatbot])
    gr.Markdown("""
---
**About this AI**: This is a 100M parameter BitNet model trained specifically for Spotless Bin Co customer service.
It runs entirely on CPU and uses ternary weights {-1, 0, +1} for extreme efficiency.
[Model](https://huggingface.co/rileyseaburg/distillix-spotless) | [GitHub](https://github.com/rileyseaburg/distillix)
""")
if __name__ == "__main__":
    demo.launch(show_error=True)