# app.py
import os
import time
from datetime import datetime, timezone
from functools import lru_cache

import gradio as gr
import torch
# Try to import Cohere SDK if present (for hosted path)
try:
    import cohere  # pip install cohere
    _HAS_COHERE = True
except Exception:
    _HAS_COHERE = False
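# NOTE: hosted mode works with either SDK generation; the messages-style
# ClientV2 used below only exists in cohere >= 5 (an assumption to pin in
# requirements.txt), while older SDKs expose just cohere.Client.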
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login, HfApi

# -------------------
# Configuration
# -------------------
MODEL_ID = os.getenv("MODEL_ID", "CohereLabs/c4ai-command-r7b-12-2024")
HF_TOKEN = (
    os.getenv("HUGGINGFACE_HUB_TOKEN")  # official Spaces name
    or os.getenv("HF_TOKEN")
)
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)
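# On Hugging Face Spaces, set these under Settings -> "Variables and secrets".
# A quick local smoke test (hypothetical values) would look like:
#   COHERE_API_KEY=... python app.py    # hosted path
#   HF_TOKEN=hf_... python app.py       # local path (GPU strongly recommended)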
# -------------------
# Helpers
# -------------------
def utc_now():
    return datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

def header(processing_time=None):
    s = (
        f"Current Date and Time (UTC - YYYY-MM-DD HH:MM:SS formatted): {utc_now()}\n"
        f"Current User's Login: Raj-VedAI\n"
    )
    if processing_time is not None:
        s += f"Processing Time: {processing_time:.2f} seconds\n"
    return s
def pick_dtype_and_map():
    if torch.cuda.is_available():
        return torch.float16, "auto"
    if torch.backends.mps.is_available():
        return torch.float16, {"": "mps"}
    return torch.float32, "cpu"  # CPU path (likely too big for R7B)
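# device_map notes: "auto" lets accelerate place/shard the model across
# available CUDA devices, while {"": "mps"} maps the whole model (the ""
# root module) onto Apple Silicon. fp16 halves memory vs fp32; as a rough
# estimate, a ~7-8B-parameter model at 2 bytes/param needs ~14-16 GB for
# weights alone, which is why the CPU/fp32 path is flagged as likely too big.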
# -------------------
# Cohere Hosted Path
# -------------------
_co_client = None     # v1 client: chat(message=...)
_co_client_v2 = None  # v2 client: chat(messages=[...]), cohere SDK >= 5
if USE_HOSTED_COHERE:
    _co_client = cohere.Client(api_key=COHERE_API_KEY)
    if hasattr(cohere, "ClientV2"):
        _co_client_v2 = cohere.ClientV2(api_key=COHERE_API_KEY)
def _cohere_parse(resp):
    """
    Handle both Cohere SDK response shapes:
    - ClientV2.chat(...): resp.message.content is a list of parts with .text
    - Client.chat(...):   resp.text
    """
    # v2 messages-style response
    if getattr(resp, "message", None) and getattr(resp.message, "content", None):
        # pick first text part
        for p in resp.message.content:
            if hasattr(p, "text") and p.text:
                return p.text.strip()
    # v1 chat response
    if hasattr(resp, "text") and resp.text:
        return resp.text.strip()
    return "Sorry, I couldn't parse the response from Cohere."
def cohere_chat(message, history):
    # history arrives in Gradio "messages" format: a list of
    # {"role": ..., "content": ...} dicts (see ChatInterface(type="messages") below)
    try:
        msgs = [{"role": m["role"], "content": m["content"]} for m in (history or [])]
        msgs.append({"role": "user", "content": message})
        if _co_client_v2 is not None:
            # Modern messages-style chat API (cohere SDK >= 5)
            resp = _co_client_v2.chat(
                model="command-r7b-12-2024",
                messages=msgs,
                temperature=0.3,
                max_tokens=350,
            )
        else:
            # Older single-message chat API
            resp = _co_client.chat(
                model="command-r7b-12-2024",
                message=message,
                temperature=0.3,
                max_tokens=350,
            )
        return _cohere_parse(resp)
    except Exception as e:
        return f"Error calling Cohere API: {e}"
# -------------------
# Local HF Path
# -------------------
@lru_cache(maxsize=1)  # load the model once, not on every chat turn
def load_local_model():
    if not HF_TOKEN:
        raise RuntimeError(
            "HUGGINGFACE_HUB_TOKEN (or HF_TOKEN) is not set. "
            "Either set it, or provide COHERE_API_KEY to use Cohere's hosted API."
        )
    login(token=HF_TOKEN, add_to_git_credential=False)
    dtype, device_map = pick_dtype_and_map()
    tok = AutoTokenizer.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        use_fast=True,
        model_max_length=4096,
        padding_side="left",
        trust_remote_code=True,
    )
    mdl = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN,
        device_map=device_map,
        low_cpu_mem_usage=True,
        torch_dtype=dtype,
        trust_remote_code=True,
    )
    if mdl.config.eos_token_id is None and tok.eos_token_id is not None:
        mdl.config.eos_token_id = tok.eos_token_id
    return mdl, tok
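# Because of the lru_cache above, repeated calls are cheap after the first:
#   model, tok = load_local_model()  # slow: downloads + loads weights
#   model, tok = load_local_model()  # fast: returns the cached pair
# (lru_cache keys on arguments; with none, maxsize=1 is effectively a singleton.)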
def build_inputs(tokenizer, message, history):
    # history arrives in Gradio "messages" format (list of role/content dicts)
    msgs = [{"role": m["role"], "content": m["content"]} for m in (history or [])]
    msgs.append({"role": "user", "content": message})
    return tokenizer.apply_chat_template(
        msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )
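# apply_chat_template renders the dict history into the model's own prompt
# format; for Command R-family models this typically uses special turn tokens
# such as <|START_OF_TURN_TOKEN|><|USER_TOKEN|> (the exact tokens come from
# the tokenizer's bundled template), and add_generation_prompt=True appends
# the assistant-turn header so generation starts in the right place.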
def local_generate(model, tokenizer, input_ids, max_new_tokens=350):
    input_ids = input_ids.to(model.device)
    with torch.no_grad():
        out = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.3,
            top_p=0.9,
            repetition_penalty=1.15,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    gen_only = out[0, input_ids.shape[-1]:]  # drop the prompt tokens
    text = tokenizer.decode(gen_only, skip_special_tokens=True)
    return text.strip()
# -------------------
# Chat callback
# -------------------
def chat_fn(message, history):
    t0 = time.time()
    try:
        if USE_HOSTED_COHERE:
            reply = cohere_chat(message, history)
            return f"{header(time.time() - t0)}{reply}"
        # Local load (GPU strongly recommended; CPU likely OOM for R7B)
        model, tokenizer = load_local_model()
        inputs = build_inputs(tokenizer, message, history)
        reply = local_generate(model, tokenizer, inputs, max_new_tokens=350)
        return f"{header(time.time() - t0)}{reply}"
    except RuntimeError as e:
        emsg = str(e)
        if "out of memory" in emsg.lower() or "cuda" in emsg.lower():
            return (
                f"{header(time.time() - t0)}Local load likely OOM. "
                "Use a GPU Space or set COHERE_API_KEY to run via Cohere hosted API."
            )
        return f"{header(time.time() - t0)}Error during chat: {e}"
    except Exception as e:
        return f"{header(time.time() - t0)}Error during chat: {e}"
# -------------------
# Connection check
# -------------------
def check_connection():
    try:
        mode = "Cohere API (hosted)" if USE_HOSTED_COHERE else "Local HF"
        if USE_HOSTED_COHERE:
            return (
                f"{header()}"
                f"Connection Status: ✅ Using Cohere hosted API\n"
                f"Mode: {mode}\n"
                f"Model: command-r7b-12-2024\n"
            )
        # Local HF metadata
        api = HfApi(token=HF_TOKEN)
        mi = api.model_info(MODEL_ID)
        return (
            f"{header()}"
            f"Connection Status: ✅ Connected\n"
            f"Mode: {mode}\n"
            f"Model: {mi.id}\n"  # ModelInfo.id (modelId is deprecated)
            f"Last Modified: {mi.lastModified}\n"
        )
    except Exception as e:
        return f"{header()}Connection Status: ❌ Error\nDetails: {e}"
# -------------------
# UI
# -------------------
with gr.Blocks(theme=gr.themes.Default()) as demo:
    gr.Markdown(f"# Medical Decision Support AI\n{header()}")
    with gr.Row():
        btn = gr.Button("Check Connection Status")
        status = gr.Textbox(label="Connection Status", lines=7, value="Click to check…")
    gr.Markdown(
        "⚙️ First response may take a moment while the model warms up. "
        "Currently configured to use **Cohere hosted API** if `COHERE_API_KEY` is set; "
        "otherwise, tries **local HF**."
    )
    chat = gr.ChatInterface(
        fn=chat_fn,
        type="messages",
        description="A medical decision support system that provides healthcare-related information and guidance.",
        examples=[
            "What are the symptoms of hypertension?",
            "What are common drug interactions with aspirin?",
            "What are the warning signs of diabetes?",
        ],
    )
    btn.click(fn=check_connection, outputs=status)
if __name__ == "__main__":
    # You can disable server-side rendering if it conflicts in your Space,
    # e.g. demo.launch(ssr_mode=False)
    demo.launch()
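# Suggested requirements.txt for this Space (versions are assumptions; pin as needed):
#   gradio
#   torch
#   transformers
#   huggingface_hub
#   cohere   # optional, hosted path only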