from __future__ import annotations  # allow `str | None` annotations on Python 3.9

import os
import re
from datetime import datetime, timezone
from functools import lru_cache
import gradio as gr
import torch
# Timezone conversion (Python 3.9+ stdlib)
try:
from zoneinfo import ZoneInfo
except Exception:
ZoneInfo = None # graceful fallback to UTC
# Try to import Cohere SDK if present (for hosted path)
try:
import cohere # pip install cohere
_HAS_COHERE = True
except Exception:
_HAS_COHERE = False
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login, HfApi
# -------------------
# Configuration
# -------------------
MODEL_ID = os.getenv("MODEL_ID", "CohereLabs/c4ai-command-r7b-12-2024")
HF_TOKEN = (
os.getenv("HUGGINGFACE_HUB_TOKEN") # official Spaces name
or os.getenv("HF_TOKEN")
)
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)
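# Routing: with a Cohere key and SDK present we call the hosted API; otherwise
# we load the (gated) weights locally, which requires an HF token.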
# -------------------
# Helpers (used for connection/status only)
# -------------------
def local_now_str(user_tz: str | None) -> tuple[str, str]:
"""Returns (label, formatted_time). Falls back to UTC if tz missing/invalid."""
label = "UTC"
dt = datetime.now(timezone.utc)
if user_tz and ZoneInfo is not None:
try:
tz = ZoneInfo(user_tz)
dt = datetime.now(tz)
label = user_tz
except Exception:
dt = datetime.now(timezone.utc)
label = "UTC"
return label, dt.strftime("%Y-%m-%d %H:%M:%S")
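# Illustrative only: local_now_str("America/New_York")
#   -> ("America/New_York", "2025-01-15 09:30:00")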
def pick_dtype_and_map():
if torch.cuda.is_available():
return torch.float16, "auto"
if torch.backends.mps.is_available():
return torch.float16, {"": "mps"}
return torch.float32, "cpu" # CPU path (likely too big for R7B)
def is_identity_query(message: str, history) -> bool:
"""Detects identity questions in current message or most recent user turn."""
patterns = [
r"\bwho\s+are\s+you\b",
r"\bwhat\s+are\s+you\b",
r"\bwhat\s+is\s+your\s+name\b",
r"\bwho\s+is\s+this\b",
r"\bidentify\s+yourself\b",
r"\btell\s+me\s+about\s+yourself\b",
r"\bdescribe\s+yourself\b",
r"\band\s+you\s*\?\b",
r"\byour\s+name\b",
r"\bwho\s+am\s+i\s+chatting\s+with\b",
]
def hit(text: str | None) -> bool:
t = (text or "").strip().lower()
return any(re.search(p, t) for p in patterns)
if hit(message):
return True
if history:
        # With gr.ChatInterface(type="messages"), history is a list of
        # {"role", "content"} dicts, not (user, assistant) tuples.
        user_turns = [m for m in history_to_messages(history) if m["role"] == "user"]
        if user_turns and hit(user_turns[-1]["content"]):
            return True
return False
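
def history_to_messages(history):
    """Normalizes Gradio chat history into role/content dicts.

    gr.ChatInterface(type="messages") already passes {"role", "content"}
    dicts; (user, assistant) tuple pairs from older Gradio versions are
    converted so the rest of the app can assume one shape.
    """
    msgs = []
    for turn in (history or []):
        if isinstance(turn, dict):
            msgs.append({"role": turn.get("role"), "content": turn.get("content")})
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            user_msg, assistant_msg = turn
            if user_msg:
                msgs.append({"role": "user", "content": user_msg})
            if assistant_msg:
                msgs.append({"role": "assistant", "content": assistant_msg})
    return msgs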
# -------------------
# Cohere Hosted Path
# -------------------
_co_client = None
if USE_HOSTED_COHERE:
    # Prefer the v2 client (messages-based chat) when the installed SDK has it.
    if hasattr(cohere, "ClientV2"):
        _co_client = cohere.ClientV2(api_key=COHERE_API_KEY)
    else:
        _co_client = cohere.Client(api_key=COHERE_API_KEY)
def _cohere_parse(resp):
    # v2 chat: the reply is a list of content items under resp.message.content
    if getattr(resp, "message", None) and getattr(resp.message, "content", None):
        for part in resp.message.content:
            if getattr(part, "text", None):
                return part.text.strip()
    # v1 chat: the reply is plain text under resp.text
    if getattr(resp, "text", None):
        return resp.text.strip()
    return "Sorry, I couldn't parse the response from Cohere."
def cohere_chat(message, history):
    try:
        msgs = history_to_messages(history)
        msgs.append({"role": "user", "content": message})
        try:
            # v2 API: the whole conversation goes in `messages`
            resp = _co_client.chat(
                model="command-r7b-12-2024",
                messages=msgs,
                temperature=0.3,
                max_tokens=350,
            )
        except TypeError:
            # v1 API: current message plus a separate `chat_history`
            chat_history = [
                {"role": "USER" if m["role"] == "user" else "CHATBOT",
                 "message": m["content"]}
                for m in msgs[:-1]
            ]
            resp = _co_client.chat(
                model="command-r7b-12-2024",
                message=message,
                chat_history=chat_history,
                temperature=0.3,
                max_tokens=350,
            )
        return _cohere_parse(resp)
    except Exception as e:
        return f"Error calling Cohere API: {e}"
# -------------------
# Local HF Path
# -------------------
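# lru_cache(maxsize=1) memoizes the (model, tokenizer) pair: the heavy
# download/load runs once per process and later chat turns reuse it.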
@lru_cache(maxsize=1)
def load_local_model():
if not HF_TOKEN:
raise RuntimeError(
"HUGGINGFACE_HUB_TOKEN (or HF_TOKEN) is not set. "
"Either set it, or provide COHERE_API_KEY to use Cohere's hosted API."
)
login(token=HF_TOKEN, add_to_git_credential=False)
dtype, device_map = pick_dtype_and_map()
tok = AutoTokenizer.from_pretrained(
MODEL_ID,
token=HF_TOKEN,
use_fast=True,
model_max_length=4096,
padding_side="left",
trust_remote_code=True,
)
mdl = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
token=HF_TOKEN,
device_map=device_map,
low_cpu_mem_usage=True,
torch_dtype=dtype,
trust_remote_code=True,
)
if mdl.config.eos_token_id is None and tok.eos_token_id is not None:
mdl.config.eos_token_id = tok.eos_token_id
return mdl, tok
def build_inputs(tokenizer, message, history):
    msgs = history_to_messages(history)
    msgs.append({"role": "user", "content": message})
return tokenizer.apply_chat_template(
msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt"
)
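# apply_chat_template renders the turns in the model's own prompt format and,
# with add_generation_prompt=True, appends the assistant-turn prefix.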
def local_generate(model, tokenizer, input_ids, max_new_tokens=350):
input_ids = input_ids.to(model.device)
with torch.no_grad():
out = model.generate(
input_ids=input_ids,
max_new_tokens=max_new_tokens,
do_sample=True,
temperature=0.3,
top_p=0.9,
repetition_penalty=1.15,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
)
gen_only = out[0, input_ids.shape[-1]:]
text = tokenizer.decode(gen_only, skip_special_tokens=True)
return text.strip()
# -------------------
# Chat callback (no header/meta in chat replies)
# -------------------
def chat_fn(message, history, user_tz):
try:
# Identity override → return ONLY the brand line
if is_identity_query(message, history):
return "I am ClarityOps, your strategic decision making AI partner."
if USE_HOSTED_COHERE:
return cohere_chat(message, history)
model, tokenizer = load_local_model()
inputs = build_inputs(tokenizer, message, history)
return local_generate(model, tokenizer, inputs, max_new_tokens=350)
except RuntimeError as e:
emsg = str(e)
if "out of memory" in emsg.lower() or "cuda" in emsg.lower():
return "Local load likely OOM. Use a GPU Space or set COHERE_API_KEY to run via Cohere hosted API."
return f"Error during chat: {e}"
except Exception as e:
return f"Error during chat: {e}"
# -------------------
# THEME & STYLES
# -------------------
theme = gr.themes.Soft(
primary_hue="teal",
neutral_hue="slate",
radius_size=gr.themes.sizes.radius_lg,
).set(
# Typeface & sizes tuned for executive readability
body_text_size="16px",
heading_text_size="28px",
shadow_drop="0 6px 24px rgba(0,0,0,.06)",
shadow_spread="0 2px 8px rgba(0,0,0,.04)",
)
custom_css = """
:root {
--brand-bg: #f6fbfb;
--brand-card: #ffffff;
--brand-text: #0f172a; /* slate-900 */
--brand-subtle: #475569; /* slate-600 */
--brand-accent: #0d9488; /* teal-600 */
--brand-accent-weak: #99f6e4; /* teal-200 */
--brand-border: #e2e8f0; /* slate-200 */
}
/* Page background and layout */
.gradio-container {
background: var(--brand-bg);
}
/* Title */
h1, .prose h1 {
  color: var(--brand-text);
  font-size: 28px;  /* heading size, moved here from the theme settings */
  font-weight: 700;
letter-spacing: -0.01em;
margin-bottom: 0.25rem !important;
}
/* Status badge wrapper */
.status-wrap {
display: flex;
align-items: center;
gap: .5rem;
margin-bottom: 0.75rem;
}
/* Badge */
.badge {
display: inline-flex;
align-items: center;
gap: .5rem;
padding: .45rem .75rem;
border-radius: 999px;
border: 1px solid var(--brand-border);
background: #ecfdf5; /* green-50 */
color: #065f46; /* green-800 */
font-weight: 600;
font-size: 14px;
}
/* Description / helper text */
.helper {
color: var(--brand-subtle);
margin: .25rem 0 1rem 0;
}
/* Card polishing */
.block, .gr-box, .gr-panel, .gr-group, .gr-form, .gradio-container .form {
border-radius: 16px !important;
}
/* Chat area spacing */
#chat-root .wrap {
padding: 0 !important;
}
/* Chat bubbles (subtle) */
.message.user {
background: #f8fafc !important; /* slate-50 */
}
.message.bot {
background: #ffffff !important;
}
/* Inputs */
textarea, input, .gr-input {
border-radius: 12px !important;
}
"""
# -------------------
# UI
# -------------------
with gr.Blocks(theme=theme, css=custom_css) as demo:
# Hidden textbox to hold browser timezone (Gradio expects components for outputs)
tz_box = gr.Textbox(visible=False)
# On load, capture browser timezone via JS and write it into tz_box
demo.load(
fn=lambda tz: tz, # echo JS value to Python
inputs=[tz_box], # 1 input required for lambda
outputs=[tz_box], # write into same hidden box
js="() => Intl.DateTimeFormat().resolvedOptions().timeZone"
)
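    # The arrow function's return value replaces the submitted input, so the
    # lambda receives the browser timezone and echoes it back into tz_box.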
# Model status (auto, no button)
def model_status(_user_tz):
try:
if USE_HOSTED_COHERE:
return (
'<div class="status-wrap">'
'<span class="badge">✅ Connected • Cohere API — model: <strong>command-r7b-12-2024</strong></span>'
"</div>"
)
api = HfApi(token=HF_TOKEN)
mi = api.model_info(MODEL_ID)
return (
'<div class="status-wrap">'
f'<span class="badge">✅ Connected • Local HF — model: <strong>{mi.modelId}</strong></span>'
"</div>"
)
except Exception as e:
return (
'<div class="status-wrap">'
f'<span class="badge" style="background:#fff7ed;color:#9a3412;border-color:#fed7aa;">'
f'⚠️ Connection Issue — {str(e)}'
'</span></div>'
)
# Header
gr.Markdown("# Medical Decision Support AI")
# Status line (renders HTML badge)
status_line = gr.HTML("<div class='status-wrap'><span class='badge'>Connecting…</span></div>")
demo.load(fn=model_status, inputs=[tz_box], outputs=[status_line])
# Subtle helper text
gr.Markdown(
"<div class='helper'>Designed for healthcare executives: concise, reliable decision support. "
"First response may take a moment while the model warms up.</div>"
)
# Chat
chat = gr.ChatInterface(
fn=chat_fn,
type="messages",
additional_inputs=[tz_box], # pass timezone into chat_fn
description="",
examples=[
["What are the symptoms of hypertension?", ""],
["What are common drug interactions with aspirin?", ""],
["What are the warning signs of diabetes?", ""],
],
        cache_examples=False,  # caching would run chat_fn (and load the model) at startup
elem_id="chat-root",
)
if __name__ == "__main__":
demo.launch()
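
# Local run (a sketch; package list is an assumption, not a pinned spec):
#   pip install gradio torch transformers huggingface_hub cohere
#   HUGGINGFACE_HUB_TOKEN=hf_xxx python app.py   # local-weights path
#   COHERE_API_KEY=xxx python app.py             # hosted Cohere path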