Rajan Sharma committed on
Commit b23412f · verified · 1 Parent(s): 88b3626

Update app.py

Files changed (1)
  1. app.py +150 -32
app.py CHANGED
@@ -4,24 +4,43 @@ import time
 from datetime import datetime, timezone
 from functools import lru_cache
 
-import torch
 import gradio as gr
+import torch
+
+# Try to import Cohere SDK if present (for hosted path)
+try:
+    import cohere  # pip install cohere
+    _HAS_COHERE = True
+except Exception:
+    _HAS_COHERE = False
+
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from huggingface_hub import login, HfApi
 
-# ---- Config ----
+
+# -------------------
+# Configuration
+# -------------------
 MODEL_ID = os.getenv("MODEL_ID", "CohereLabs/c4ai-command-r7b-12-2024")
+
 HF_TOKEN = (
-    os.getenv("HUGGINGFACE_HUB_TOKEN")  # canonical name in HF Spaces
+    os.getenv("HUGGINGFACE_HUB_TOKEN")  # official Spaces name
     or os.getenv("HF_TOKEN")
 )
 
-def utc_now() -> str:
+COHERE_API_KEY = os.getenv("COHERE_API_KEY")
+
+USE_HOSTED_COHERE = bool(COHERE_API_KEY and _HAS_COHERE)
+
+# -------------------
+# Helpers
+# -------------------
+def utc_now():
     return datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
 
-def header(processing_time=None) -> str:
+def header(processing_time=None):
     s = (
-        f"Current Date and Time (UTC - YYYY-MM-DD HH:MM:SS formatted): {utc_now()}\n"
+        f"Current Date and Time (UTC - YYYY-MM-DD HH:MM:SS formatted): {utc_now()} "
         f"Current User's Login: Raj-VedAI\n"
     )
     if processing_time is not None:
@@ -33,39 +52,99 @@ def pick_dtype_and_map():
         return torch.float16, "auto"
     if torch.backends.mps.is_available():
         return torch.float16, {"": "mps"}
-    return torch.float32, "cpu"
+    return torch.float32, "cpu"  # CPU path (likely too big for R7B)
+
+
+# -------------------
+# Cohere Hosted Path
+# -------------------
+_co_client = None
+if USE_HOSTED_COHERE:
+    _co_client = cohere.Client(api_key=COHERE_API_KEY)
+
+def _cohere_parse(resp):
+    """
+    Handle both Cohere SDK styles:
+    - responses.create(...): resp.output_text or resp.message.content[0].text
+    - chat(...): resp.text
+    """
+    # v5+ responses.create
+    if hasattr(resp, "output_text") and resp.output_text:
+        return resp.output_text.strip()
+    if getattr(resp, "message", None) and getattr(resp.message, "content", None):
+        parts = resp.message.content
+        # pick first text part
+        for p in parts:
+            if hasattr(p, "text") and p.text:
+                return p.text.strip()
+    # v4 chat
+    if hasattr(resp, "text") and resp.text:
+        return resp.text.strip()
+    return "Sorry, I couldn't parse the response from Cohere."
+
+def cohere_chat(message, history):
+    # Build structured messages from the Gradio history
+    # (the older chat fallback below sends only the latest message)
+    try:
+        # Try the newer responses-style API first
+        try:
+            msgs = []
+            for u, a in (history or []):
+                msgs.append({"role": "user", "content": u})
+                msgs.append({"role": "assistant", "content": a})
+            msgs.append({"role": "user", "content": message})
+            resp = _co_client.responses.create(
+                model="command-r7b-12-2024",
+                messages=msgs,
+                temperature=0.3,
+                max_tokens=350,
+            )
+        except Exception:
+            # Fall back to the older chat API
+            resp = _co_client.chat(
+                model="command-r7b-12-2024",
+                message=message,
+                temperature=0.3,
+                max_tokens=350,
+            )
+        return _cohere_parse(resp)
+    except Exception as e:
+        return f"Error calling Cohere API: {e}"
+
 
+# -------------------
+# Local HF Path
+# -------------------
 @lru_cache(maxsize=1)
-def load_model():
-    # Login (optional for public models; safe if token is unset)
-    if HF_TOKEN:
-        login(token=HF_TOKEN, add_to_git_credential=False)
+def load_local_model():
+    if not HF_TOKEN:
+        raise RuntimeError(
+            "HUGGINGFACE_HUB_TOKEN (or HF_TOKEN) is not set. "
+            "Either set it, or provide COHERE_API_KEY to use Cohere's hosted API."
+        )
 
-    dtype, device_map = pick_dtype_and_map()
+    login(token=HF_TOKEN, add_to_git_credential=False)
 
-    tokenizer = AutoTokenizer.from_pretrained(
+    dtype, device_map = pick_dtype_and_map()
+    tok = AutoTokenizer.from_pretrained(
         MODEL_ID,
         token=HF_TOKEN,
        use_fast=True,
         model_max_length=4096,
         padding_side="left",
-        trust_remote_code=True,  # <- allow custom model code
+        trust_remote_code=True,
     )
-
-    model = AutoModelForCausalLM.from_pretrained(
+    mdl = AutoModelForCausalLM.from_pretrained(
         MODEL_ID,
         token=HF_TOKEN,
         device_map=device_map,
         low_cpu_mem_usage=True,
         torch_dtype=dtype,
-        trust_remote_code=True,  # <- allow custom model code
+        trust_remote_code=True,
     )
-
-    # Ensure EOS configured
-    if model.config.eos_token_id is None and tokenizer.eos_token_id is not None:
-        model.config.eos_token_id = tokenizer.eos_token_id
-
-    return model, tokenizer
+    if mdl.config.eos_token_id is None and tok.eos_token_id is not None:
+        mdl.config.eos_token_id = tok.eos_token_id
+    return mdl, tok
 
 def build_inputs(tokenizer, message, history):
     msgs = []
@@ -74,13 +153,10 @@ def build_inputs(tokenizer, message, history):
         msgs.append({"role": "assistant", "content": a})
     msgs.append({"role": "user", "content": message})
     return tokenizer.apply_chat_template(
-        msgs,
-        tokenize=True,
-        add_generation_prompt=True,
-        return_tensors="pt",
+        msgs, tokenize=True, add_generation_prompt=True, return_tensors="pt"
     )
 
-def generate_reply(model, tokenizer, input_ids, max_new_tokens=300):
+def local_generate(model, tokenizer, input_ids, max_new_tokens=350):
     input_ids = input_ids.to(model.device)
     with torch.no_grad():
         out = model.generate(
@@ -97,37 +173,77 @@ def generate_reply(model, tokenizer, input_ids, max_new_tokens=300):
     text = tokenizer.decode(gen_only, skip_special_tokens=True)
     return text.strip()
 
+
+# -------------------
+# Chat callback
+# -------------------
 def chat_fn(message, history):
     t0 = time.time()
     try:
-        model, tokenizer = load_model()
+        if USE_HOSTED_COHERE:
+            reply = cohere_chat(message, history)
+            return f"{header(time.time() - t0)}{reply}"
+
+        # Local load (GPU strongly recommended; CPU likely OOM for R7B)
+        model, tokenizer = load_local_model()
         inputs = build_inputs(tokenizer, message, history)
-        reply = generate_reply(model, tokenizer, inputs, max_new_tokens=350)
+        reply = local_generate(model, tokenizer, inputs, max_new_tokens=350)
         return f"{header(time.time() - t0)}{reply}"
+
+    except RuntimeError as e:
+        emsg = str(e)
+        if "out of memory" in emsg.lower() or "cuda" in emsg.lower():
+            return (
+                f"{header(time.time() - t0)}Local load likely OOM. "
+                "Use a GPU Space or set COHERE_API_KEY to run via the Cohere hosted API."
+            )
+        return f"{header(time.time() - t0)}Error during chat: {e}"
     except Exception as e:
         return f"{header(time.time() - t0)}Error during chat: {e}"
 
+
+# -------------------
+# Connection check
+# -------------------
 def check_connection():
     try:
+        mode = "Cohere API (hosted)" if USE_HOSTED_COHERE else "Local HF"
+        if USE_HOSTED_COHERE:
+            return (
+                f"{header()}"
+                f"Connection Status: ✅ Using Cohere hosted API\n"
+                f"Mode: {mode}\n"
+                f"Model: command-r7b-12-2024\n"
+            )
+        # Local HF metadata
         api = HfApi(token=HF_TOKEN)
         mi = api.model_info(MODEL_ID)
         return (
             f"{header()}"
             f"Connection Status: ✅ Connected\n"
+            f"Mode: {mode}\n"
             f"Model: {mi.modelId}\n"
             f"Last Modified: {mi.lastModified}\n"
         )
     except Exception as e:
         return f"{header()}Connection Status: ❌ Error\nDetails: {e}"
 
+
+# -------------------
+# UI
+# -------------------
 with gr.Blocks(theme=gr.themes.Default()) as demo:
     gr.Markdown(f"# Medical Decision Support AI\n{header()}")
 
     with gr.Row():
         btn = gr.Button("Check Connection Status")
-        status = gr.Textbox(label="Connection Status", lines=6, value="Click to check…")
+        status = gr.Textbox(label="Connection Status", lines=7, value="Click to check…")
 
-    gr.Markdown("⚙️ First response may take a moment while the model warms up.")
+    gr.Markdown(
+        "⚙️ First response may take a moment while the model warms up. "
+        "Currently configured to use **Cohere hosted API** if `COHERE_API_KEY` is set; "
+        "otherwise, tries **local HF**."
+    )
 
     chat = gr.ChatInterface(
         fn=chat_fn,
@@ -143,6 +259,8 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
     btn.click(fn=check_connection, outputs=status)
 
 if __name__ == "__main__":
+    # You can disable SSR (demo.launch(ssr_mode=False)) if it conflicts in your Space:
     demo.launch()
 
 
+
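
A note on the commit's central switch: USE_HOSTED_COHERE is computed once at import time, and the hosted path is taken only when COHERE_API_KEY is set and the cohere SDK actually imports. A minimal standalone sketch of that decision (env-var names come from the diff; the function name pick_backend is made up for illustration):

import os

def pick_backend(has_cohere_sdk: bool) -> str:
    # Hosted requires BOTH the API key and an importable SDK;
    # anything else falls back to local transformers loading,
    # which in turn requires HUGGINGFACE_HUB_TOKEN or HF_TOKEN.
    use_hosted = bool(os.getenv("COHERE_API_KEY") and has_cohere_sdk)
    return "cohere-hosted" if use_hosted else "local-hf"

print(pick_backend(has_cohere_sdk=True))   # "cohere-hosted" only if COHERE_API_KEY is set
print(pick_backend(has_cohere_sdk=False))  # always "local-hf"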
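The new _cohere_parse helper duck-types across two SDK response shapes rather than pinning a version. The stand-in classes below are hypothetical, not real Cohere SDK types, and the import assumes this snippet runs in the same directory as app.py; they simply exercise the parser's two branches:

from app import _cohere_parse  # assumes app.py is importable from the same directory

class V4StyleResp:
    # mimics the older chat(...) response: a plain .text attribute
    text = "hello from the v4-style chat API"

class TextPart:
    def __init__(self, text):
        self.text = text

class Message:
    def __init__(self, parts):
        self.content = parts

class V5StyleResp:
    # mimics the newer shape: .message.content is a list of text parts;
    # output_text is left empty so the parser falls through to the parts
    output_text = ""
    message = Message([TextPart("hello from the v5-style API")])

print(_cohere_parse(V4StyleResp()))  # -> hello from the v4-style chat API
print(_cohere_parse(V5StyleResp()))  # -> hello from the v5-style API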
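chat_fn also gains a separate RuntimeError branch so CUDA and out-of-memory failures surface actionable advice instead of a raw traceback. A minimal sketch of that classification in isolation (same substring test as the committed code; classify_runtime_error is a hypothetical name):

def classify_runtime_error(e: RuntimeError) -> str:
    # Matches the check in chat_fn: treat CUDA or out-of-memory
    # messages as a likely OOM and suggest a GPU Space or the hosted API.
    emsg = str(e).lower()
    if "out of memory" in emsg or "cuda" in emsg:
        return "likely-oom"
    return "other"

print(classify_runtime_error(RuntimeError("CUDA error: out of memory")))  # likely-oom
print(classify_runtime_error(RuntimeError("tokenizer mismatch")))         # other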