# app.py
import os
import time
from datetime import datetime, timezone
from functools import lru_cache

import gradio as gr
import torch

# Try to import the Cohere SDK if present (for the hosted path)
try:
    import cohere  # pip install cohere
    _HAS_COHERE = True
except Exception:
    _HAS_COHERE = False

from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login, HfApi
# -------------------
# Configuration
# -------------------
MODEL_ID = os.getenv("MODEL_ID", "CohereLabs/c4ai-command-r7b-12-2024")
HF_TOKEN = (
    os.getenv("HUGGINGFACE_HUB_TOKEN")  # official Spaces secret name
    or os.getenv("HF_TOKEN")
)
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)
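
# Illustrative Space configuration (hypothetical values), set as Space secrets
# or exported in the shell before launch:
#   export COHERE_API_KEY="co-..."          # enables the hosted path
#   export HUGGINGFACE_HUB_TOKEN="hf_..."   # enables the local path
#   export MODEL_ID="CohereLabs/c4ai-command-r7b-12-2024"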
# -------------------
# Helpers
# -------------------
def utc_now():
    return datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

def header(processing_time=None):
    s = (
        f"Current Date and Time (UTC, YYYY-MM-DD HH:MM:SS): {utc_now()}\n"
        f"Current User's Login: Raj-VedAI\n"
    )
    if processing_time is not None:
        s += f"Processing Time: {processing_time:.2f} seconds\n"
    return s
def pick_dtype_and_map():
    if torch.cuda.is_available():
        return torch.float16, "auto"
    if torch.backends.mps.is_available():
        return torch.float16, {"": "mps"}
    return torch.float32, "cpu"  # CPU path (likely too big for R7B)
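
# Note on the device maps above: "auto" lets accelerate distribute layers
# across available CUDA devices, while {"": "mps"} places the whole model on
# the Apple Silicon (MPS) backend. fp16 halves memory versus fp32, which
# matters for a 7B-parameter checkpoint.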
# -------------------
# Cohere Hosted Path
# -------------------
_co_client = None
_co_client_v2 = None
if USE_HOSTED_COHERE:
    _co_client = cohere.Client(api_key=COHERE_API_KEY)
    # ClientV2 (messages-style chat) only exists in newer SDK releases
    _client_v2_cls = getattr(cohere, "ClientV2", None)
    if _client_v2_cls is not None:
        _co_client_v2 = _client_v2_cls(api_key=COHERE_API_KEY)
def _cohere_parse(resp):
    """
    Handle both Cohere SDK response shapes:
    - ClientV2.chat(...): resp.message.content is a list of text parts
    - Client.chat(...):   resp.text is the reply string
    """
    # v2 messages-style response
    if getattr(resp, "message", None) and getattr(resp.message, "content", None):
        # pick the first text part
        for part in resp.message.content:
            if hasattr(part, "text") and part.text:
                return part.text.strip()
    # v1 chat response
    if hasattr(resp, "text") and resp.text:
        return resp.text.strip()
    return "Sorry, I couldn't parse the response from Cohere."
def cohere_chat(message, history):
    # Convert the Gradio history into a messages list. With
    # gr.ChatInterface(type="messages") the history entries are already
    # {"role": ..., "content": ...} dicts; legacy (user, assistant) tuples
    # are handled as a fallback.
    msgs = []
    for turn in (history or []):
        if isinstance(turn, dict):
            msgs.append({"role": turn["role"], "content": turn["content"]})
        else:
            u, a = turn
            msgs.append({"role": "user", "content": u})
            msgs.append({"role": "assistant", "content": a})
    msgs.append({"role": "user", "content": message})
    try:
        try:
            # Try the modern v2 API first (structured message history)
            if _co_client_v2 is None:
                raise RuntimeError("cohere.ClientV2 is not available")
            resp = _co_client_v2.chat(
                model="command-r7b-12-2024",
                messages=msgs,
                temperature=0.3,
                max_tokens=350,
            )
        except Exception:
            # Fall back to the older v1 chat API (single message string)
            resp = _co_client.chat(
                model="command-r7b-12-2024",
                message=message,
                temperature=0.3,
                max_tokens=350,
            )
        return _cohere_parse(resp)
    except Exception as e:
        return f"Error calling Cohere API: {e}"
# -------------------
# Local HF Path
# -------------------
@lru_cache(maxsize=1)
def load_local_model():
    if not HF_TOKEN:
        raise RuntimeError(
            "HUGGINGFACE_HUB_TOKEN (or HF_TOKEN) is not set. "
            "Either set it, or provide COHERE_API_KEY to use Cohere's hosted API."
        )
    login(token=HF_TOKEN, add_to_git_credential=False)
    dtype, device_map = pick_dtype_and_map()
    tok = AutoTokenizer.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        use_fast=True,
        model_max_length=4096,
        padding_side="left",
        trust_remote_code=True,
    )
    mdl = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        device_map=device_map,
        low_cpu_mem_usage=True,
        torch_dtype=dtype,
        trust_remote_code=True,
    )
    if mdl.config.eos_token_id is None and tok.eos_token_id is not None:
        mdl.config.eos_token_id = tok.eos_token_id
    return mdl, tok
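
# Note: lru_cache(maxsize=1) makes this a lazy singleton: the model and
# tokenizer load on the first chat turn and are reused for every later call,
# so only the first request pays the load cost.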
def build_inputs(tokenizer, message, history):
    # Same history normalization as cohere_chat: accept dict or tuple turns
    msgs = []
    for turn in (history or []):
        if isinstance(turn, dict):
            msgs.append({"role": turn["role"], "content": turn["content"]})
        else:
            u, a = turn
            msgs.append({"role": "user", "content": u})
            msgs.append({"role": "assistant", "content": a})
    msgs.append({"role": "user", "content": message})
    return tokenizer.apply_chat_template(
        msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )
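
# apply_chat_template wraps each turn in the model's chat control tokens and
# returns a [1, seq_len] tensor of input ids ready for generate(). A sketch of
# the input it receives (hypothetical content):
#   [{"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello!"},
#    {"role": "user", "content": "What are the warning signs of diabetes?"}]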
def local_generate(model, tokenizer, input_ids, max_new_tokens=350):
    input_ids = input_ids.to(model.device)
    with torch.no_grad():
        out = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.3,
            top_p=0.9,
            repetition_penalty=1.15,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    gen_only = out[0, input_ids.shape[-1]:]
    text = tokenizer.decode(gen_only, skip_special_tokens=True)
    return text.strip()
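
# Note: generate() returns prompt + continuation in one tensor, so the slice
# at input_ids.shape[-1] above keeps only the newly generated tokens before
# decoding; skip_special_tokens drops the chat control tokens from the text.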
# -------------------
# Chat callback
# -------------------
def chat_fn(message, history):
    t0 = time.time()
    try:
        if USE_HOSTED_COHERE:
            reply = cohere_chat(message, history)
            return f"{header(time.time() - t0)}{reply}"
        # Local load (GPU strongly recommended; CPU will likely OOM on R7B)
        model, tokenizer = load_local_model()
        inputs = build_inputs(tokenizer, message, history)
        reply = local_generate(model, tokenizer, inputs, max_new_tokens=350)
        return f"{header(time.time() - t0)}{reply}"
    except RuntimeError as e:
        emsg = str(e).lower()
        if "out of memory" in emsg or "cuda" in emsg:
            return (
                f"{header(time.time() - t0)}Local load likely ran out of memory. "
                "Use a GPU Space or set COHERE_API_KEY to run via Cohere's hosted API."
            )
        return f"{header(time.time() - t0)}Error during chat: {e}"
    except Exception as e:
        return f"{header(time.time() - t0)}Error during chat: {e}"
# -------------------
# Connection check
# -------------------
def check_connection():
    try:
        mode = "Cohere API (hosted)" if USE_HOSTED_COHERE else "Local HF"
        if USE_HOSTED_COHERE:
            return (
                f"{header()}"
                f"Connection Status: ✅ Using Cohere hosted API\n"
                f"Mode: {mode}\n"
                f"Model: command-r7b-12-2024\n"
            )
        # Local HF metadata
        api = HfApi(token=HF_TOKEN)
        mi = api.model_info(MODEL_ID)
        return (
            f"{header()}"
            f"Connection Status: ✅ Connected\n"
            f"Mode: {mode}\n"
            f"Model: {mi.id}\n"
            f"Last Modified: {mi.last_modified}\n"
        )
    except Exception as e:
        return f"{header()}Connection Status: ❌ Error\nDetails: {e}"
# -------------------
# UI
# -------------------
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown(f"# Medical Decision Support AI\n{header()}")
    with gr.Row():
        btn = gr.Button("Check Connection Status")
        status = gr.Textbox(label="Connection Status", lines=7, value="Click to check…")
    gr.Markdown(
        "⚙️ The first response may take a moment while the model warms up. "
        "The app uses the **Cohere hosted API** when `COHERE_API_KEY` is set; "
        "otherwise it falls back to **local HF** inference."
    )
    chat = gr.ChatInterface(
        fn=chat_fn,
        type="messages",
        description="A medical decision support system that provides healthcare-related information and guidance.",
        examples=[
            "What are the symptoms of hypertension?",
            "What are common drug interactions with aspirin?",
            "What are the warning signs of diabetes?",
        ],
    )
    btn.click(fn=check_connection, outputs=status)
if __name__ == "__main__":
    # Disable server-side rendering if it conflicts in your Space,
    # e.g. demo.launch(ssr_mode=False)
    demo.launch()