import os
import re
from datetime import datetime, timezone
from functools import lru_cache

import gradio as gr
import torch

# Timezone conversion (Python 3.9+ stdlib)
try:
    from zoneinfo import ZoneInfo
except Exception:
    ZoneInfo = None  # graceful fallback to UTC

# Try to import Cohere SDK if present (for hosted path)
try:
    import cohere  # pip install cohere

    _HAS_COHERE = True
except Exception:
    _HAS_COHERE = False

from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login, HfApi

# -------------------
# Configuration
# -------------------
MODEL_ID = os.getenv("MODEL_ID", "CohereLabs/c4ai-command-r7b-12-2024")
# Token lookup order: the variable this app has always read first, then the
# names huggingface_hub itself recognises (HF_TOKEN, and its legacy
# HUGGING_FACE_HUB_TOKEN spelling).
HF_TOKEN = (
    os.getenv("HUGGINGFACE_HUB_TOKEN")
    or os.getenv("HF_TOKEN")
    or os.getenv("HUGGING_FACE_HUB_TOKEN")
)
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
# The hosted path is used only when both an API key and the SDK are present.
USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)

_TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S"


# -------------------
# Helpers
# -------------------
def local_now_str(user_tz: str | None) -> tuple[str, str]:
    """Return ``(label, formatted_time)`` for the current wall-clock time.

    ``label`` is ``user_tz`` when ``zoneinfo`` can load that zone, otherwise
    ``"UTC"``. ``formatted_time`` is ``YYYY-MM-DD HH:MM:SS`` in that zone.
    A missing zone, an unknown/malformed zone, or an unavailable ``zoneinfo``
    module all fall back to UTC.
    """
    if user_tz and ZoneInfo is not None:
        try:
            return user_tz, datetime.now(ZoneInfo(user_tz)).strftime(_TIMESTAMP_FORMAT)
        except Exception:
            pass  # deliberate best-effort: unusable zone name -> UTC below
    return "UTC", datetime.now(timezone.utc).strftime(_TIMESTAMP_FORMAT)


def pick_dtype_and_map():
    """Choose ``(torch_dtype, device_map)`` for loading the local model.

    CUDA -> float16 with ``"auto"`` placement; Apple MPS -> float16 pinned to
    ``"mps"``; otherwise float32 on CPU (likely too big for R7B).
    """
    if torch.cuda.is_available():
        return torch.float16, "auto"
    # Older torch builds have no torch.backends.mps at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.float16, {"": "mps"}
    return torch.float32, "cpu"


# Phrases treated as "who/what is this assistant?" questions. Matching runs on
# lower-cased text, so the patterns are written in lower case.
_IDENTITY_PATTERNS = (
    r"\bwho\s+are\s+you\b",
    r"\bwhat\s+are\s+you\b",
    r"\bwhat\s+is\s+your\s+name\b",
    r"\bwho\s+is\s+this\b",
    r"\bidentify\s+yourself\b",
    r"\btell\s+me\s+about\s+yourself\b",
    r"\bdescribe\s+yourself\b",
    # No trailing \b: after "?" a word boundary exists only if a word character
    # follows, so the old "...\?\b" form never matched "and you?" at the end.
    r"\band\s+you\s*\?",
    r"\byour\s+name\b",
    r"\bwho\s+am\s+i\s+chatting\s+with\b",
)
_IDENTITY_RE = re.compile("|".join(f"(?:{p})" for p in _IDENTITY_PATTERNS))


def is_identity_query(message: str, history=None) -> bool:
    """Return True if the current user message asks who/what the assistant is.

    ``history`` is still accepted so existing callers keep working, but it is
    deliberately not consulted. Gradio passes the current message separately
    from ``history``, so ``history[-1]`` is the *previous* user turn. Checking
    it made the message right after an identity question also receive the
    canned identity reply.
    """
    text = (message or "").strip().lower()
    return _IDENTITY_RE.search(text) is not None


# -------------------
# Cohere Hosted Path
# -------------------
_co_client = None
if USE_HOSTED_COHERE:
    _co_client = cohere.Client(api_key=COHERE_API_KEY)


def _cohere_parse(resp):
    """Extract the reply text from a Cohere chat response object.

    Tries, in order:
    - a flat ``output_text`` attribute;
    - the first non-empty text part of ``resp.message.content`` (v2-style
      chat responses);
    - ``resp.text`` (v1 ``chat`` responses).

    Returns a fixed apology string when none of them yields text.
    """
    output_text = getattr(resp, "output_text", None)
    if output_text:
        return output_text.strip()

    message = getattr(resp, "message", None)
    if message and getattr(message, "content", None):
        for part in message.content:
            part_text = getattr(part, "text", None)
            if part_text:
                return part_text.strip()

    text = getattr(resp, "text", None)
    if text:
        return text.strip()

    return "Sorry, I couldn't parse the response from Cohere."
def cohere_chat(message, history): try: # Prefer modern API try: msgs = [] for u, a in (history or []): msgs.append({"role": "user", "content": u}) msgs.append({"role": "assistant", "content": a}) msgs.append({"role": "user", "content": message}) resp = _co_client.responses.create( model="command-r7b-12-2024", messages=msgs, temperature=0.3, max_tokens=350, ) except Exception: # Fallback to older chat API resp = _co_client.chat( model="command-r7b-12-2024", message=message, temperature=0.3, max_tokens=350, ) return _cohere_parse(resp) except Exception as e: return f"Error calling Cohere API: {e}" # ------------------- # Local HF Path # ------------------- @lru_cache(maxsize=1) def load_local_model(): if not HF_TOKEN: raise RuntimeError( "HUGGINGFACE_HUB_TOKEN (or HF_TOKEN) is not set. " "Either set it, or provide COHERE_API_KEY to use Cohere's hosted API." ) login(token=HF_TOKEN, add_to_git_credential=False) dtype, device_map = pick_dtype_and_map() tok = AutoTokenizer.from_pretrained( MODEL_ID, token=HF_TOKEN, use_fast=True, model_max_length=4096, padding_side="left", trust_remote_code=True, ) mdl = AutoModelForCausalLM.from_pretrained( MODEL_ID, token=HF_TOKEN, device_map=device_map, low_cpu_mem_usage=True, torch_dtype=dtype, trust_remote_code=True, ) if mdl.config.eos_token_id is None and tok.eos_token_id is not None: mdl.config.eos_token_id = tok.eos_token_id return mdl, tok def build_inputs(tokenizer, message, history): msgs = [] for u, a in (history or []): msgs.append({"role": "user", "content": u}) msgs.append({"role": "assistant", "content": a}) msgs.append({"role": "user", "content": message}) return tokenizer.apply_chat_template( msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt" ) def local_generate(model, tokenizer, input_ids, max_new_tokens=350): input_ids = input_ids.to(model.device) with torch.no_grad(): out = model.generate( input_ids=input_ids, max_new_tokens=max_new_tokens, do_sample=True, temperature=0.3, top_p=0.9, 
repetition_penalty=1.15, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id, ) gen_only = out[0, input_ids.shape[-1]:] text = tokenizer.decode(gen_only, skip_special_tokens=True) return text.strip() # ------------------- # Chat callback (no header/meta in chat replies) # ------------------- def chat_fn(message, history, user_tz): try: # Identity override → return ONLY the brand line if is_identity_query(message, history): return "I am ClarityOps, your strategic decision making AI partner." if USE_HOSTED_COHERE: return cohere_chat(message, history) model, tokenizer = load_local_model() inputs = build_inputs(tokenizer, message, history) return local_generate(model, tokenizer, inputs, max_new_tokens=350) except RuntimeError as e: emsg = str(e) if "out of memory" in emsg.lower() or "cuda" in emsg.lower(): return "Local load likely OOM. Use a GPU Space or set COHERE_API_KEY to run via Cohere hosted API." return f"Error during chat: {e}" except Exception as e: return f"Error during chat: {e}" # ------------------- # THEME & STYLES # ------------------- theme = gr.themes.Soft( primary_hue="teal", neutral_hue="slate", radius_size=gr.themes.sizes.radius_lg, ).set( # Typeface & sizes tuned for executive readability body_text_size="16px", heading_text_size="28px", shadow_drop="0 6px 24px rgba(0,0,0,.06)", shadow_spread="0 2px 8px rgba(0,0,0,.04)", ) custom_css = """ :root { --brand-bg: #f6fbfb; --brand-card: #ffffff; --brand-text: #0f172a; /* slate-900 */ --brand-subtle: #475569; /* slate-600 */ --brand-accent: #0d9488; /* teal-600 */ --brand-accent-weak: #99f6e4; /* teal-200 */ --brand-border: #e2e8f0; /* slate-200 */ } /* Page background and layout */ .gradio-container { background: var(--brand-bg); } /* Title */ h1, .prose h1 { color: var(--brand-text); font-weight: 700; letter-spacing: -0.01em; margin-bottom: 0.25rem !important; } /* Status badge wrapper */ .status-wrap { display: flex; align-items: center; gap: .5rem; margin-bottom: 0.75rem; } /* 
Badge */ .badge { display: inline-flex; align-items: center; gap: .5rem; padding: .45rem .75rem; border-radius: 999px; border: 1px solid var(--brand-border); background: #ecfdf5; /* green-50 */ color: #065f46; /* green-800 */ font-weight: 600; font-size: 14px; } /* Description / helper text */ .helper { color: var(--brand-subtle); margin: .25rem 0 1rem 0; } /* Card polishing */ .block, .gr-box, .gr-panel, .gr-group, .gr-form, .gradio-container .form { border-radius: 16px !important; } /* Chat area spacing */ #chat-root .wrap { padding: 0 !important; } /* Chat bubbles (subtle) */ .message.user { background: #f8fafc !important; /* slate-50 */ } .message.bot { background: #ffffff !important; } /* Inputs */ textarea, input, .gr-input { border-radius: 12px !important; } """ # ------------------- # UI # ------------------- with gr.Blocks(theme=theme, css=custom_css) as demo: # Hidden textbox to hold browser timezone (Gradio expects components for outputs) tz_box = gr.Textbox(visible=False) # On load, capture browser timezone via JS and write it into tz_box demo.load( fn=lambda tz: tz, # echo JS value to Python inputs=[tz_box], # 1 input required for lambda outputs=[tz_box], # write into same hidden box js="() => Intl.DateTimeFormat().resolvedOptions().timeZone" ) # Model status (auto, no button) def model_status(_user_tz): try: if USE_HOSTED_COHERE: return ( '