# messaging / app.py
# Alifjo123's picture
# Update app.py
# c2a02b3 verified
import time
import html
from datetime import datetime
import gradio as gr
import torch
from transformers import RobertaTokenizerFast, RobertaForSequenceClassification
# ---------------------------
# Model setup: tokenizer + sequence classifier, moved to the chosen device
# ---------------------------
MODEL_PATH = "Alifjo123/robertaBase_messaging_100k"  # Hugging Face model id

# Use CUDA when PyTorch reports it as available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = RobertaTokenizerFast.from_pretrained(MODEL_PATH)
model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH)
model.eval()  # the app only runs predictions, never training
model.to(device)
# ---------------------------
# Util: classify a single text
# ---------------------------
def predict_label(text: str):
    """Classify one message with the module-level tokenizer and model.

    The text is tokenized with truncation at ``max_length=512``, run through
    the model without gradient tracking, and the logits are softmaxed.

    Args:
        text: The message to classify.

    Returns:
        A tuple ``(pred, safe_pct, unsafe_pct)`` where ``pred`` is 1 (unsafe)
        when the class-1 probability is strictly greater than the class-0
        probability and 0 (safe) otherwise, and the two percentages are the
        class-0 and class-1 probabilities scaled to 0-100.
    """
    encoded = tokenizer(
        text,
        truncation=True,
        padding=True,
        max_length=512,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        logits = model(**encoded).logits

    # First (only) row of the batch: per-class probabilities.
    class_probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]
    safe_prob = class_probs[0]
    unsafe_prob = class_probs[1]

    pred = int(unsafe_prob > safe_prob)  # 1=unsafe, 0=safe
    return pred, safe_prob * 100.0, unsafe_prob * 100.0
# ---------------------------
# Render chat (HTML) from state
# ---------------------------
def render_chat(messages):
    """Render the conversation history as an HTML fragment.

    Args:
        messages: Iterable of message dicts with the keys ``role``, ``text``,
            ``pred`` (1 = unsafe), ``safe`` and ``unsafe`` (percentages) and
            ``ts`` (Unix timestamp). ``None`` is treated as an empty history.

    Returns:
        A string holding a single ``<div class="chat">`` element with one
        bubble row per message. Messages from ``"User A"`` are aligned right,
        all other roles left.
    """
    html_parts = ['<div class="chat">']
    for m in messages or ():
        is_user_a = m["role"] == "User A"
        side = "right" if is_user_a else "left"
        bubble_class = "bubble-a" if is_user_a else "bubble-b"

        is_unsafe = m["pred"] == 1
        label = "Unsafe ❌" if is_unsafe else "Safe ✅"
        chip_class = "chip-unsafe" if is_unsafe else "chip-safe"

        # Everything user-influenced that lands in markup is escaped. str()
        # keeps a missing (None) role rendering as text instead of raising.
        role = html.escape(str(m["role"]))
        text = html.escape(m["text"])

        safe = f'{m["safe"]:.1f}%'
        unsafe = f'{m["unsafe"]:.1f}%'
        # Naive fromtimestamp(): formatted in the local time of this process.
        ts = datetime.fromtimestamp(m["ts"]).strftime("%H:%M")

        # The text div is kept on one line: the stylesheet renders it with
        # white-space: pre-wrap, so stray whitespace inside it would show up.
        html_parts.append(f"""
<div class="row {side}">
<div class="bubble {bubble_class}">
<div class="meta">
<span class="name">{role}</span>
<span class="time">{ts}</span>
</div>
<div class="text">{text}</div>
<div class="badges">
<span class="chip {chip_class}">{label}</span>
<span class="probs">Safe {safe} · Unsafe {unsafe}</span>
</div>
</div>
</div>
""")
    html_parts.append("</div>")
    return "\n".join(html_parts)
# ---------------------------
# Gradio callbacks
# ---------------------------
def send_message(role, message, messages):
    """Classify a new message, add it to the history and re-render the chat.

    Args:
        role: Sender label stored with the message (e.g. ``"User A"``).
        message: Raw textbox content. ``None`` or whitespace-only input is
            treated as empty and ignored.
        messages: Conversation history as a list of message dicts, or
            ``None`` for a fresh conversation. The list is appended to in
            place.

    Returns:
        A 3-tuple ``(chat_html, textbox_value, messages)``: the rendered chat
        (or a no-op ``gr.update()`` when nothing was sent), an empty string
        for the textbox, and the history list.
    """
    if messages is None:
        messages = []

    # Guard against a None textbox value the same way the history is guarded;
    # calling .strip() on None would raise AttributeError.
    message = (message or "").strip()
    if not message:
        return gr.update(), "", messages  # ignore empty

    pred, safe_pct, unsafe_pct = predict_label(message)
    messages.append({
        "role": role,
        "text": message,
        "pred": pred,
        "safe": safe_pct,
        "unsafe": unsafe_pct,
        "ts": time.time(),
    })
    return render_chat(messages), "", messages
def clear_chat():
    """Reset the conversation.

    Returns:
        A 2-tuple of the HTML for an empty chat and a new empty history list.
    """
    empty_markup = render_chat([])
    fresh_history = []
    return empty_markup, fresh_history
# ---------------------------
# Custom CSS
# Pink/blue theme: the page gradient and the two bubble colours are pink and
# blue; the Safe/Unsafe chips additionally use green and red accents.
# Class names here pair with the markup built in render_chat (.chat, .row,
# .bubble*, .meta, .text, .badges, .chip*) and with the elem_classes / inline
# HTML used when the UI is built (.panel, .controls, .textbox, .header,
# .footer-note).
# ---------------------------
CSS = """
* { box-sizing: border-box; }
:root {
--bg-gradient: linear-gradient(135deg, #ff99cc, #66ccff);
--bubble-a: #ff66b2;
--bubble-b: #3399ff;
--text-light: #f9f9f9;
--chip-safe: #00e676;
--chip-unsafe: #ff5252;
}
body {
background: var(--bg-gradient);
font-family: 'Segoe UI', sans-serif;
color: var(--text-light);
}
.gradio-container { max-width: 800px !important; margin: 0 auto; }
.header {
text-align:center;
padding:16px;
background: rgba(0,0,0,.35);
border-radius: 16px;
font-size:22px; font-weight:700;
margin-bottom: 12px;
color: #fff;
box-shadow: 0 6px 18px rgba(0,0,0,.25);
}
.panel {
background: rgba(0,0,0,0.4);
border-radius: 20px;
overflow: hidden;
box-shadow: 0 8px 28px rgba(0,0,0,.35);
backdrop-filter: blur(8px);
}
.chat {
padding: 16px;
height: 480px;
overflow-y: auto;
}
.row { display:flex; margin: 10px 0; }
.row.right { justify-content: flex-end; }
.row.left { justify-content: flex-start; }
.bubble {
max-width: 70%;
padding: 10px 14px;
border-radius: 16px;
color: var(--text-light);
box-shadow: 0 4px 12px rgba(0,0,0,.25);
animation: fadeIn .25s ease-out;
}
.bubble-a { background: var(--bubble-a); }
.bubble-b { background: var(--bubble-b); }
.meta {
display:flex; justify-content: space-between;
font-size: 12px; opacity:.9; margin-bottom: 4px;
}
.text { white-space: pre-wrap; line-height: 1.35; font-size: 14.5px; }
.badges { margin-top: 6px; font-size: 12px; opacity:.95; display:flex; gap:8px; }
.chip { padding:2px 8px; border-radius:12px; font-weight:600; }
.chip-safe { background:#004d40; color:var(--chip-safe); }
.chip-unsafe { background:#3d0000; color:var(--chip-unsafe); }
.controls {
padding: 10px;
display:flex; align-items:center; gap:10px;
background: rgba(0,0,0,.35);
}
.controls .textbox { flex:1; }
.footer-note {
font-size: 12px; text-align:center; margin-top: 8px; opacity:.8; color:#eee;
}
@keyframes fadeIn { from{opacity:0;transform:translateY(6px);} to{opacity:1;transform:translateY(0);} }
"""
# ---------------------------
# Build UI
# ---------------------------
# NOTE(review): leading indentation is not visible in this view, so which
# components sit inside the Group / Row containers cannot be confirmed here —
# verify the nesting against the deployed layout before editing.
# Top-level Blocks app; the CSS constant is supplied as its custom stylesheet.
with gr.Blocks(css=CSS) as demo:
with gr.Group(elem_classes="panel"):
gr.HTML('<div class="header">💬 Let\'s Chat</div>')
# Chat area, initialised with the rendering of an empty history.
chat_html = gr.HTML(render_chat([]), elem_id="chat")
# State value holding the list of message dicts; starts empty.
messages_state = gr.State([])
with gr.Row(elem_classes="controls"):
# Sender selector; its value is passed to send_message as `role`.
role = gr.Dropdown(["User A", "User B"], value="User A", label="Role")
msg = gr.Textbox(placeholder="Type a message…", label=None, lines=2, elem_classes="textbox")
send = gr.Button("Send", variant="primary")
clear = gr.Button("Clear", variant="secondary")
# The Send button and the textbox submit event share one handler with the
# same inputs and outputs: chat HTML, textbox value, and history state.
send.click(send_message, inputs=[role, msg, messages_state], outputs=[chat_html, msg, messages_state])
msg.submit(send_message, inputs=[role, msg, messages_state], outputs=[chat_html, msg, messages_state])
# Clear takes no inputs and resets both the chat HTML and the history state.
clear.click(clear_chat, outputs=[chat_html, messages_state])
gr.Markdown('<div class="footer-note">Model: <code>Alifjo123/robertaBase_messaging_100k</code></div>')
# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
demo.launch()