import os
import re
from datetime import datetime, timezone
from functools import lru_cache
import gradio as gr
import torch
# Timezone conversion (Python 3.9+ stdlib)
try:
from zoneinfo import ZoneInfo
except Exception:
ZoneInfo = None # graceful fallback to UTC
# Try to import Cohere SDK if present (for hosted path)
try:
import cohere # pip install cohere
_HAS_COHERE = True
except Exception:
_HAS_COHERE = False
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login, HfApi
# -------------------
# Configuration
# -------------------
MODEL_ID = os.getenv("MODEL_ID", "CohereLabs/c4ai-command-r7b-12-2024")
HF_TOKEN = (
os.getenv("HUGGINGFACE_HUB_TOKEN") # official Spaces name
or os.getenv("HF_TOKEN")
)
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)
# -------------------
# Helpers (used for connection/status only)
# -------------------
def local_now_str(user_tz: str | None) -> tuple[str, str]:
"""Returns (label, formatted_time). Falls back to UTC if tz missing/invalid."""
label = "UTC"
dt = datetime.now(timezone.utc)
if user_tz and ZoneInfo is not None:
try:
tz = ZoneInfo(user_tz)
dt = datetime.now(tz)
label = user_tz
except Exception:
dt = datetime.now(timezone.utc)
label = "UTC"
return label, dt.strftime("%Y-%m-%d %H:%M:%S")
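# Illustrative behavior (timestamps are whatever "now" happens to be):
#   local_now_str("America/New_York") -> ("America/New_York", "2025-01-15 09:30:00")
#   local_now_str("not/a/zone")       -> ("UTC", "2025-01-15 14:30:00")
#   local_now_str(None)               -> ("UTC", "2025-01-15 14:30:00")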
def pick_dtype_and_map():
if torch.cuda.is_available():
return torch.float16, "auto"
if torch.backends.mps.is_available():
return torch.float16, {"": "mps"}
return torch.float32, "cpu" # CPU path (likely too big for R7B)
def history_to_messages(history):
    """Normalizes Gradio chat history into [{"role", "content"}, ...] dicts.

    With ChatInterface(type="messages") the history already arrives in that
    shape; legacy tuple-format [(user, assistant), ...] histories are converted.
    """
    msgs = []
    for turn in (history or []):
        if isinstance(turn, dict):
            role, content = turn.get("role"), turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str):
                msgs.append({"role": role, "content": content})
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            user_msg, bot_msg = turn
            if user_msg:
                msgs.append({"role": "user", "content": user_msg})
            if bot_msg:
                msgs.append({"role": "assistant", "content": bot_msg})
    return msgs

def is_identity_query(message: str, history) -> bool:
    """Detects identity questions in the current message or the most recent user turn."""
    patterns = [
        r"\bwho\s+are\s+you\b",
        r"\bwhat\s+are\s+you\b",
        r"\bwhat\s+is\s+your\s+name\b",
        r"\bwho\s+is\s+this\b",
        r"\bidentify\s+yourself\b",
        r"\btell\s+me\s+about\s+yourself\b",
        r"\bdescribe\s+yourself\b",
        r"\band\s+you\s*\?",
        r"\byour\s+name\b",
        r"\bwho\s+am\s+i\s+chatting\s+with\b",
    ]
    def hit(text: str | None) -> bool:
        t = (text or "").strip().lower()
        return any(re.search(p, t) for p in patterns)
    if hit(message):
        return True
    # Check the most recent user turn in the history, if any.
    for m in reversed(history_to_messages(history)):
        if m["role"] == "user":
            return hit(m["content"])
    return False
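# Illustrative behavior (hypothetical inputs):
#   is_identity_query("Who are you?", [])                                -> True
#   is_identity_query("ok", [{"role": "user", "content": "your name?"}]) -> True
#   is_identity_query("What are the symptoms of hypertension?", [])      -> False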
# -------------------
# Cohere Hosted Path
# -------------------
_co_client = None
_co_client_v2 = None
if USE_HOSTED_COHERE:
    # Prefer the v2 client (messages-based chat API in cohere>=5); keep v1 as a fallback.
    _co_client = cohere.Client(api_key=COHERE_API_KEY)
    if hasattr(cohere, "ClientV2"):
        _co_client_v2 = cohere.ClientV2(api_key=COHERE_API_KEY)

def _cohere_parse(resp):
    # v2 chat: text lives in resp.message.content[i].text
    if getattr(resp, "message", None) and getattr(resp.message, "content", None):
        for p in resp.message.content:
            if hasattr(p, "text") and p.text:
                return p.text.strip()
    # v1 chat: plain resp.text
    if hasattr(resp, "text") and resp.text:
        return resp.text.strip()
    return "Sorry, I couldn't parse the response from Cohere."

def cohere_chat(message, history):
    try:
        msgs = history_to_messages(history)
        msgs.append({"role": "user", "content": message})
        if _co_client_v2 is not None:
            # Modern messages-based API
            resp = _co_client_v2.chat(
                model="command-r7b-12-2024",
                messages=msgs,
                temperature=0.3,
                max_tokens=350,
            )
        else:
            # Older v1 chat API: single message plus explicit chat history
            chat_history = [
                {"role": "USER" if m["role"] == "user" else "CHATBOT", "message": m["content"]}
                for m in msgs[:-1]
            ]
            resp = _co_client.chat(
                model="command-r7b-12-2024",
                message=message,
                chat_history=chat_history,
                temperature=0.3,
                max_tokens=350,
            )
        return _cohere_parse(resp)
    except Exception as e:
        return f"Error calling Cohere API: {e}"
# -------------------
# Local HF Path
# -------------------
@lru_cache(maxsize=1)
def load_local_model():
if not HF_TOKEN:
raise RuntimeError(
"HUGGINGFACE_HUB_TOKEN (or HF_TOKEN) is not set. "
"Either set it, or provide COHERE_API_KEY to use Cohere's hosted API."
)
login(token=HF_TOKEN, add_to_git_credential=False)
dtype, device_map = pick_dtype_and_map()
tok = AutoTokenizer.from_pretrained(
MODEL_ID,
token=HF_TOKEN,
use_fast=True,
model_max_length=4096,
padding_side="left",
trust_remote_code=True,
)
mdl = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
token=HF_TOKEN,
device_map=device_map,
low_cpu_mem_usage=True,
torch_dtype=dtype,
trust_remote_code=True,
)
if mdl.config.eos_token_id is None and tok.eos_token_id is not None:
mdl.config.eos_token_id = tok.eos_token_id
return mdl, tok
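# If the Space has a small GPU, a 4-bit quantized load can make R7B fit where
# fp16 will not. A minimal sketch with a hypothetical load_local_model_4bit
# helper (assumes the optional `bitsandbytes` package and a CUDA device; not
# wired into the app):
#
#   from transformers import BitsAndBytesConfig
#
#   def load_local_model_4bit():
#       bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
#       tok = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN, trust_remote_code=True)
#       mdl = AutoModelForCausalLM.from_pretrained(
#           MODEL_ID, token=HF_TOKEN, quantization_config=bnb,
#           device_map="auto", trust_remote_code=True,
#       )
#       return mdl, tok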
def build_inputs(tokenizer, message, history):
    msgs = history_to_messages(history)
    msgs.append({"role": "user", "content": message})
    return tokenizer.apply_chat_template(
        msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )
def local_generate(model, tokenizer, input_ids, max_new_tokens=350):
input_ids = input_ids.to(model.device)
with torch.no_grad():
out = model.generate(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=0.3,
top_p=0.9,
repetition_penalty=1.15,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
)
gen_only = out[0, input_ids.shape[-1]:]
text = tokenizer.decode(gen_only, skip_special_tokens=True)
return text.strip()
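# Streaming variant (not wired into the UI): transformers' TextIteratorStreamer
# runs generate() on a background thread and yields decoded text as it arrives.
# A minimal sketch, assuming the same model/tokenizer as local_generate above:
def local_generate_stream(model, tokenizer, input_ids, max_new_tokens=350):
    from threading import Thread
    from transformers import TextIteratorStreamer
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    gen_kwargs = dict(
        input_ids=input_ids.to(model.device),
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=0.3,
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    Thread(target=model.generate, kwargs=gen_kwargs, daemon=True).start()
    text = ""
    for chunk in streamer:  # blocks until the next decoded piece is ready
        text += chunk
        yield text  # yield the accumulated reply so a UI can re-render it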
# -------------------
# Chat callback (no header/meta in chat replies)
# -------------------
def chat_fn(message, history, user_tz):
try:
# Identity override → return ONLY the brand line
if is_identity_query(message, history):
return "I am ClarityOps, your strategic decision making AI partner."
if USE_HOSTED_COHERE:
return cohere_chat(message, history)
model, tokenizer = load_local_model()
inputs = build_inputs(tokenizer, message, history)
return local_generate(model, tokenizer, inputs, max_new_tokens=350)
except RuntimeError as e:
emsg = str(e)
if "out of memory" in emsg.lower() or "cuda" in emsg.lower():
return "Local load likely OOM. Use a GPU Space or set COHERE_API_KEY to run via Cohere hosted API."
return f"Error during chat: {e}"
except Exception as e:
return f"Error during chat: {e}"
# -------------------
# THEME & STYLES
# -------------------
theme = gr.themes.Soft(
primary_hue="teal",
neutral_hue="slate",
radius_size=gr.themes.sizes.radius_lg,
).set(
# Typeface & sizes tuned for executive readability
body_text_size="16px",
heading_text_size="28px",
shadow_drop="0 6px 24px rgba(0,0,0,.06)",
shadow_spread="0 2px 8px rgba(0,0,0,.04)",
)
custom_css = """
:root {
--brand-bg: #f6fbfb;
--brand-card: #ffffff;
--brand-text: #0f172a; /* slate-900 */
--brand-subtle: #475569; /* slate-600 */
--brand-accent: #0d9488; /* teal-600 */
--brand-accent-weak: #99f6e4; /* teal-200 */
--brand-border: #e2e8f0; /* slate-200 */
}
/* Page background and layout */
.gradio-container {
background: var(--brand-bg);
}
/* Title */
h1, .prose h1 {
color: var(--brand-text);
font-weight: 700;
letter-spacing: -0.01em;
margin-bottom: 0.25rem !important;
}
/* Status badge wrapper */
.status-wrap {
display: flex;
align-items: center;
gap: .5rem;
margin-bottom: 0.75rem;
}
/* Badge */
.badge {
display: inline-flex;
align-items: center;
gap: .5rem;
padding: .45rem .75rem;
border-radius: 999px;
border: 1px solid var(--brand-border);
background: #ecfdf5; /* green-50 */
color: #065f46; /* green-800 */
font-weight: 600;
font-size: 14px;
}
/* Description / helper text */
.helper {
color: var(--brand-subtle);
margin: .25rem 0 1rem 0;
}
/* Card polishing */
.block, .gr-box, .gr-panel, .gr-group, .gr-form, .gradio-container .form {
border-radius: 16px !important;
}
/* Chat area spacing */
#chat-root .wrap {
padding: 0 !important;
}
/* Chat bubbles (subtle) */
.message.user {
background: #f8fafc !important; /* slate-50 */
}
.message.bot {
background: #ffffff !important;
}
/* Inputs */
textarea, input, .gr-input {
border-radius: 12px !important;
}
"""
# -------------------
# UI
# -------------------
with gr.Blocks(theme=theme, css=custom_css) as demo:
# Hidden textbox to hold browser timezone (Gradio expects components for outputs)
tz_box = gr.Textbox(visible=False)
# On load, capture browser timezone via JS and write it into tz_box
demo.load(
fn=lambda tz: tz, # echo JS value to Python
inputs=[tz_box], # 1 input required for lambda
outputs=[tz_box], # write into same hidden box
js="() => Intl.DateTimeFormat().resolvedOptions().timeZone"
)
# Model status (auto, no button)
def model_status(_user_tz):
try:
if USE_HOSTED_COHERE:
return (
'<div class="status-wrap">'
'<span class="badge">✅ Connected • Cohere API — model: <strong>command-r7b-12-2024</strong></span>'
"</div>"
)
api = HfApi(token=HF_TOKEN)
mi = api.model_info(MODEL_ID)
return (
'<div class="status-wrap">'
f'<span class="badge">✅ Connected • Local HF — model: <strong>{mi.modelId}</strong></span>'
"</div>"
)
except Exception as e:
return (
'<div class="status-wrap">'
f'<span class="badge" style="background:#fff7ed;color:#9a3412;border-color:#fed7aa;">'
f'⚠️ Connection Issue — {str(e)}'
'</span></div>'
)
# Header
gr.Markdown("# Medical Decision Support AI")
# Status line (renders HTML badge)
status_line = gr.HTML("<div class='status-wrap'><span class='badge'>Connecting…</span></div>")
demo.load(fn=model_status, inputs=[tz_box], outputs=[status_line])
# Subtle helper text
gr.Markdown(
"<div class='helper'>Designed for healthcare executives: concise, reliable decision support. "
"First response may take a moment while the model warms up.</div>"
)
# Chat
    chat = gr.ChatInterface(
        fn=chat_fn,
        type="messages",  # history arrives as [{"role", "content"}, ...] dicts
        additional_inputs=[tz_box],  # pass timezone into chat_fn
        description="",
        examples=[  # each example is [message, tz] to match additional_inputs
            ["What are the symptoms of hypertension?", ""],
            ["What are common drug interactions with aspirin?", ""],
            ["What are the warning signs of diabetes?", ""],
        ],
        cache_examples=False,  # caching would invoke the model/API at startup
        elem_id="chat-root",
    )
if __name__ == "__main__":
demo.launch()