Rajan Sharma committed
Commit 1c47f55 · verified · 1 Parent(s): bcc9046

Update app.py

Files changed (1): app.py (+38, −39)
app.py CHANGED
@@ -9,79 +9,78 @@ import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from huggingface_hub import login, HfApi

-MODEL_ID = os.getenv("MODEL_ID", "CohereLabs/c4ai-command-a-03-2025")  # change if needed
+# ---- Config ----
+MODEL_ID = os.getenv("MODEL_ID", "CohereLabs/c4ai-command-r7b-12-2024")
 HF_TOKEN = (
-    os.getenv("HUGGINGFACE_HUB_TOKEN")  # <-- correct canonical name
+    os.getenv("HUGGINGFACE_HUB_TOKEN")  # canonical name in HF Spaces
     or os.getenv("HF_TOKEN")
 )

-def get_timestamp():
+def utc_now() -> str:
     return datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

-def format_system_info(processing_time=None):
-    info = (
-        f"Current Date and Time (UTC - YYYY-MM-DD HH:MM:SS formatted): {get_timestamp()}\n"
+def header(processing_time=None) -> str:
+    s = (
+        f"Current Date and Time (UTC - YYYY-MM-DD HH:MM:SS formatted): {utc_now()}\n"
         f"Current User's Login: Raj-VedAI\n"
     )
     if processing_time is not None:
-        info += f"Processing Time: {processing_time:.2f} seconds\n"
-    return info
+        s += f"Processing Time: {processing_time:.2f} seconds\n"
+    return s

-def _pick_dtype_and_map():
+def pick_dtype_and_map():
     if torch.cuda.is_available():
         return torch.float16, "auto"
     if torch.backends.mps.is_available():
-        # Apple Silicon (MPS) prefers float16/bfloat16 depending on model; float16 is usually OK.
         return torch.float16, {"": "mps"}
-    return torch.float32, "cpu"  # CPU-safe
+    return torch.float32, "cpu"

 @lru_cache(maxsize=1)
 def load_model():
+    # Login (optional for public models; safe if token is unset)
     if HF_TOKEN:
-        # In Spaces this isn’t strictly necessary if the secret is set, but it doesn’t hurt.
         login(token=HF_TOKEN, add_to_git_credential=False)

-    dtype, device_map = _pick_dtype_and_map()
+    dtype, device_map = pick_dtype_and_map()

-    tok = AutoTokenizer.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(
         MODEL_ID,
         token=HF_TOKEN,
         use_fast=True,
         model_max_length=4096,
-        padding_side="left",  # safer for some chat templates
+        padding_side="left",
+        trust_remote_code=True,  # <- allow custom model code
     )

-    mdl = AutoModelForCausalLM.from_pretrained(
+    model = AutoModelForCausalLM.from_pretrained(
         MODEL_ID,
         token=HF_TOKEN,
         device_map=device_map,
         low_cpu_mem_usage=True,
         torch_dtype=dtype,
+        trust_remote_code=True,  # <- allow custom model code
     )

-    # Fallback for models without an EOS defined
-    if mdl.config.eos_token_id is None and tok.eos_token_id is not None:
-        mdl.config.eos_token_id = tok.eos_token_id
+    # Ensure EOS configured
+    if model.config.eos_token_id is None and tokenizer.eos_token_id is not None:
+        model.config.eos_token_id = tokenizer.eos_token_id

-    return mdl, tok
+    return model, tokenizer

 def build_inputs(tokenizer, message, history):
-    # Convert Gradio’s (message, history) into a chat template
     msgs = []
-    # Optionally carry past turns if your model supports it
-    for u, a in history or []:
+    for u, a in (history or []):
         msgs.append({"role": "user", "content": u})
         msgs.append({"role": "assistant", "content": a})
     msgs.append({"role": "user", "content": message})
-    inputs = tokenizer.apply_chat_template(
+    return tokenizer.apply_chat_template(
         msgs,
         tokenize=True,
         add_generation_prompt=True,
         return_tensors="pt",
     )
-    return inputs

-def generate_reply(model, tokenizer, input_ids, max_new_tokens=256):
+def generate_reply(model, tokenizer, input_ids, max_new_tokens=300):
     input_ids = input_ids.to(model.device)
     with torch.no_grad():
         out = model.generate(
@@ -90,50 +89,49 @@ def generate_reply(model, tokenizer, input_ids, max_new_tokens=256):
             do_sample=True,
             temperature=0.3,
             top_p=0.9,
-            repetition_penalty=1.2,
+            repetition_penalty=1.15,
             pad_token_id=tokenizer.eos_token_id,
             eos_token_id=tokenizer.eos_token_id,
         )
-    # Slice off the prompt so we only return new tokens
     gen_only = out[0, input_ids.shape[-1]:]
     text = tokenizer.decode(gen_only, skip_special_tokens=True)
     return text.strip()

 def chat_fn(message, history):
-    start = time.time()
+    t0 = time.time()
     try:
         model, tokenizer = load_model()
         inputs = build_inputs(tokenizer, message, history)
-        reply = generate_reply(model, tokenizer, inputs, max_new_tokens=300)
-        # Optional: prepend system info once per turn
-        reply = f"{format_system_info(time.time() - start)}{reply}"
-        return reply
+        reply = generate_reply(model, tokenizer, inputs, max_new_tokens=350)
+        return f"{header(time.time() - t0)}{reply}"
     except Exception as e:
-        return f"{format_system_info(time.time() - start)}Error during chat: {e}"
+        return f"{header(time.time() - t0)}Error during chat: {e}"

 def check_connection():
     try:
         api = HfApi(token=HF_TOKEN)
         mi = api.model_info(MODEL_ID)
         return (
-            f"{format_system_info()}"
+            f"{header()}"
             f"Connection Status: ✅ Connected\n"
             f"Model: {mi.modelId}\n"
             f"Last Modified: {mi.lastModified}\n"
         )
     except Exception as e:
-        return f"{format_system_info()}Connection Status: ❌ Error\nDetails: {e}"
+        return f"{header()}Connection Status: ❌ Error\nDetails: {e}"

 with gr.Blocks(theme=gr.themes.Default()) as demo:
-    gr.Markdown(f"# Medical Decision Support AI\n{format_system_info()}")
+    gr.Markdown(f"# Medical Decision Support AI\n{header()}")
+
     with gr.Row():
         btn = gr.Button("Check Connection Status")
         status = gr.Textbox(label="Connection Status", lines=6, value="Click to check…")
-    gr.Markdown("⚙️ Model is loading on first request. Please wait for the first answer.")
+
+    gr.Markdown("⚙️ First response may take a moment while the model warms up.")

     chat = gr.ChatInterface(
         fn=chat_fn,
-        type="messages",  # use the modern message format
+        type="messages",
         description="A medical decision support system that provides healthcare-related information and guidance.",
         examples=[
             "What are the symptoms of hypertension?",
@@ -147,3 +145,4 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
 if __name__ == "__main__":
     demo.launch()

+
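
Note for a follow-up commit: with type="messages", Gradio's ChatInterface passes history as a list of {"role": ..., "content": ...} dicts rather than (user, assistant) tuples, so the tuple unpacking in build_inputs will misbehave on multi-turn chats. A minimal sketch of a messages-compatible version (same function name and call sites as in the diff; everything else assumed unchanged):

# Sketch: build_inputs for ChatInterface(type="messages").
# Gradio passes each history item as {"role": "user"|"assistant", "content": "..."}.
def build_inputs(tokenizer, message, history):
    msgs = [{"role": m["role"], "content": m["content"]} for m in (history or [])]
    msgs.append({"role": "user", "content": message})
    return tokenizer.apply_chat_template(
        msgs,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )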
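For local verification before pushing, a quick smoke test of the load/generate path helps. A sketch under stated assumptions: the file is named app.py and importable from the repo root, MODEL_ID is overridable via the environment (it is, per the diff), and the stand-in model is any small instruct model that ships a chat template (HuggingFaceTB/SmolLM2-135M-Instruct is used here as an example, not something the commit references):

# Hypothetical smoke test: swap in a small chat model so load_model()
# and generate_reply() can be exercised without the full-size download.
import os
os.environ["MODEL_ID"] = "HuggingFaceTB/SmolLM2-135M-Instruct"  # assumed small stand-in

# Must come after setting MODEL_ID, since app.py reads it at import time.
# Importing builds the Gradio UI but does not launch it (launch is __main__-guarded).
from app import load_model, build_inputs, generate_reply

model, tokenizer = load_model()
inputs = build_inputs(tokenizer, "Hello, what can you do?", history=None)
print(generate_reply(model, tokenizer, inputs, max_new_tokens=32))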