|
|
import time |
|
|
import html |
|
|
from datetime import datetime |
|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import RobertaTokenizerFast, RobertaForSequenceClassification |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Identifier handed to `from_pretrained` for both tokenizer and model
# (looks like a Hugging Face Hub repo id — a local path would also work).
MODEL_PATH = "Alifjo123/robertaBase_messaging_100k"

# Run on the first CUDA device when torch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = RobertaTokenizerFast.from_pretrained(MODEL_PATH)

# Load the classifier, switch it to evaluation mode (disables dropout) and
# move its weights to `device`; both calls return the module itself.
model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH).eval().to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_label(text: str):
    """Classify a single message with the module-level RoBERTa model.

    The text is tokenized (truncated to at most 512 tokens), sent to
    ``device`` and run through ``model`` without gradient tracking. Softmax
    probabilities of the first input row are then read for class 0 and
    class 1.

    NOTE(review): assumes the classification head has exactly two labels —
    any further labels would be silently ignored here; confirm with the
    model card.

    Args:
        text: Message to classify.

    Returns:
        ``(pred, p0_pct, p1_pct)`` — ``pred`` is ``1`` when the class-1
        probability is strictly greater than the class-0 probability and
        ``0`` otherwise (ties go to 0); the two percentages are those
        probabilities multiplied by 100.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )
    encoded = encoded.to(device)

    with torch.no_grad():
        logits = model(**encoded).logits

    # Only one text is encoded, so row 0 holds its class distribution.
    row = torch.softmax(logits, dim=-1).cpu().numpy()[0]
    p_class0, p_class1 = row[0], row[1]

    label = 1 if p_class1 > p_class0 else 0
    return label, p_class0 * 100.0, p_class1 * 100.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def render_chat(messages):
    """Render the chat history as one HTML fragment.

    Messages from ``"User A"`` are right-aligned with the ``bubble-a`` style;
    every other role is left-aligned with ``bubble-b``. Each bubble shows the
    sender, an ``HH:MM`` local timestamp, the message text, a Safe/Unsafe chip
    (``pred == 1`` means unsafe) and both percentages to one decimal place.

    Args:
        messages: Sequence of dicts with keys ``role``, ``text``, ``pred``,
            ``safe``, ``unsafe`` and ``ts`` (a POSIX timestamp). ``None`` is
            treated as an empty history.

    Returns:
        An HTML string wrapped in ``<div class="chat">``.
    """
    html_parts = ['<div class="chat">']

    for m in messages or []:
        side = "right" if m["role"] == "User A" else "left"
        bubble_class = "bubble-a" if side == "right" else "bubble-b"

        is_unsafe = m["pred"] == 1
        label = "Unsafe ❌" if is_unsafe else "Safe ✅"
        chip_class = "chip-unsafe" if is_unsafe else "chip-safe"

        safe = f'{m["safe"]:.1f}%'
        unsafe = f'{m["unsafe"]:.1f}%'
        ts = datetime.fromtimestamp(m["ts"]).strftime("%H:%M")

        # Both role and text originate from the client, so escape both
        # before interpolating them into markup (role previously was not).
        role = html.escape(str(m["role"]))
        text = html.escape(m["text"])

        html_parts.append(f"""
<div class="row {side}">
  <div class="bubble {bubble_class}">
    <div class="meta">
      <span class="name">{role}</span>
      <span class="time">{ts}</span>
    </div>
    <div class="text">{text}</div>
    <div class="badges">
      <span class="chip {chip_class}">{label}</span>
      <span class="probs">Safe {safe} · Unsafe {unsafe}</span>
    </div>
  </div>
</div>
""")

    html_parts.append("</div>")
    return "\n".join(html_parts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def send_message(role, message, messages):
    """Classify a new chat message and append it to the conversation.

    Args:
        role: Speaker label chosen in the Role dropdown (e.g. ``"User A"``).
        message: Raw textbox content; may be ``None`` when the component
            delivers no value.
        messages: Current history list held in ``gr.State`` (``None`` is
            treated as empty). It is not mutated; a new list is returned.

    Returns:
        ``(chat_html, textbox_value, history)``. Blank input leaves the chat
        panel untouched (``gr.update()``), clears the textbox and returns the
        history unchanged. Otherwise the rendered chat, an empty textbox and
        the extended history are returned.
    """
    history = list(messages) if messages else []

    # `message` can be None, so guard before calling str methods on it.
    text = (message or "").strip()
    if not text:
        return gr.update(), "", history

    pred, safe_pct, unsafe_pct = predict_label(text)

    history.append({
        "role": role,
        "text": text,
        "pred": pred,
        "safe": safe_pct,
        "unsafe": unsafe_pct,
        "ts": time.time(),
    })

    return render_chat(history), "", history
|
|
|
|
|
def clear_chat():
    """Reset the conversation.

    Returns:
        A pair of (HTML for an empty chat panel, a fresh empty history list).
    """
    fresh_history = []
    return render_chat(fresh_history), fresh_history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Stylesheet injected via gr.Blocks(css=...): gradient page background, a
# translucent chat panel, coloured left/right bubbles and Safe/Unsafe chips.
# `.text` uses `overflow-wrap: anywhere` so long unbroken strings (URLs,
# keyboard-mash spam) wrap inside the 70%-wide bubble instead of overflowing it.
CSS = """
* { box-sizing: border-box; }

:root {
  --bg-gradient: linear-gradient(135deg, #ff99cc, #66ccff);
  --bubble-a: #ff66b2;
  --bubble-b: #3399ff;
  --text-light: #f9f9f9;
  --chip-safe: #00e676;
  --chip-unsafe: #ff5252;
}

body {
  background: var(--bg-gradient);
  font-family: 'Segoe UI', sans-serif;
  color: var(--text-light);
}

.gradio-container { max-width: 800px !important; margin: 0 auto; }

.header {
  text-align:center;
  padding:16px;
  background: rgba(0,0,0,.35);
  border-radius: 16px;
  font-size:22px; font-weight:700;
  margin-bottom: 12px;
  color: #fff;
  box-shadow: 0 6px 18px rgba(0,0,0,.25);
}

.panel {
  background: rgba(0,0,0,0.4);
  border-radius: 20px;
  overflow: hidden;
  box-shadow: 0 8px 28px rgba(0,0,0,.35);
  backdrop-filter: blur(8px);
}

.chat {
  padding: 16px;
  height: 480px;
  overflow-y: auto;
}

.row { display:flex; margin: 10px 0; }
.row.right { justify-content: flex-end; }
.row.left { justify-content: flex-start; }

.bubble {
  max-width: 70%;
  padding: 10px 14px;
  border-radius: 16px;
  color: var(--text-light);
  box-shadow: 0 4px 12px rgba(0,0,0,.25);
  animation: fadeIn .25s ease-out;
}
.bubble-a { background: var(--bubble-a); }
.bubble-b { background: var(--bubble-b); }

.meta {
  display:flex; justify-content: space-between;
  font-size: 12px; opacity:.9; margin-bottom: 4px;
}
.text { white-space: pre-wrap; overflow-wrap: anywhere; line-height: 1.35; font-size: 14.5px; }
.badges { margin-top: 6px; font-size: 12px; opacity:.95; display:flex; gap:8px; }
.chip { padding:2px 8px; border-radius:12px; font-weight:600; }
.chip-safe { background:#004d40; color:var(--chip-safe); }
.chip-unsafe { background:#3d0000; color:var(--chip-unsafe); }

.controls {
  padding: 10px;
  display:flex; align-items:center; gap:10px;
  background: rgba(0,0,0,.35);
}
.controls .textbox { flex:1; }

.footer-note {
  font-size: 12px; text-align:center; margin-top: 8px; opacity:.8; color:#eee;
}

@keyframes fadeIn { from{opacity:0;transform:translateY(6px);} to{opacity:1;transform:translateY(0);} }
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Assemble the single-page UI: a header, the rendered chat log, and a row of
# controls for choosing the speaker and sending / clearing messages.
with gr.Blocks(css=CSS) as demo:
    with gr.Group(elem_classes="panel"):
        gr.HTML('<div class="header">💬 Let\'s Chat</div>')

        chat_html = gr.HTML(render_chat([]), elem_id="chat")
        # Per-session list of message dicts (schema built in send_message).
        messages_state = gr.State([])

        with gr.Row(elem_classes="controls"):
            role = gr.Dropdown(["User A", "User B"], value="User A", label="Role")
            msg = gr.Textbox(placeholder="Type a message…", label=None, lines=2, elem_classes="textbox")
            send = gr.Button("Send", variant="primary")
            clear = gr.Button("Clear", variant="secondary")

    # The Send button and textbox submission share one handler, so share
    # one input/output spec too and keep the two wirings from drifting apart.
    send_io = {
        "inputs": [role, msg, messages_state],
        "outputs": [chat_html, msg, messages_state],
    }
    send.click(send_message, **send_io)
    msg.submit(send_message, **send_io)

    clear.click(clear_chat, outputs=[chat_html, messages_state])

    # Build the footer from MODEL_PATH so it always names the model that is
    # actually loaded (it was previously a duplicated hard-coded string).
    gr.Markdown(
        f'<div class="footer-note">Model: <code>{html.escape(MODEL_PATH)}</code></div>'
    )
|
|
|
|
|
def _run_app():
    """Start the Gradio server for ``demo`` with default launch settings."""
    demo.launch()


if __name__ == "__main__":
    _run_app()
|
|
|