Rajan Sharma committed on
Commit 11a5624 · verified · 1 Parent(s): 76d5714

Update app.py

Files changed (1)
  1. app.py +110 -166
app.py CHANGED
@@ -1,33 +1,24 @@
-import shutil
+# app.py
 import os
-
-# Clear HuggingFace cache directory on every launch
-shutil.rmtree(os.path.expanduser("~/.cache/huggingface"), ignore_errors=True)
-shutil.rmtree("offload", ignore_errors=True)  # Or whatever folder you use for offloading/cache
-
-
+import time
+from datetime import datetime, timezone
+from functools import lru_cache
+
+import torch
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM
-from datetime import datetime, timezone
-import os
 from huggingface_hub import login, HfApi
-from huggingface_hub.utils import RepositoryNotFoundError, HfHubHTTPError
-import time
-import requests
-from tenacity import retry, stop_after_attempt, wait_exponential
-from functools import lru_cache
-import torch
 
-# Global variables for model caching
-global_model = None
-global_tokenizer = None
+MODEL_ID = os.getenv("MODEL_ID", "CohereLabs/c4ai-command-a-03-2025")  # change if needed
+HF_TOKEN = (
+    os.getenv("HF_TOKEN")  # canonical name read by huggingface_hub
+    or os.getenv("HUGGING_FACE_HUB_TOKEN")  # legacy fallback
+)
 
 def get_timestamp():
-    """Get current UTC datetime in specified format"""
-    return datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%SS')
+    return datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
 
 def format_system_info(processing_time=None):
-    """Format system information header"""
     info = (
         f"Current Date and Time (UTC - YYYY-MM-DD HH:MM:SS formatted): {get_timestamp()}\n"
         f"Current User's Login: Raj-VedAI\n"
@@ -36,170 +27,123 @@ def format_system_info(processing_time=None):
         info += f"Processing Time: {processing_time:.2f} seconds\n"
     return info
 
+def _pick_dtype_and_map():
+    if torch.cuda.is_available():
+        return torch.float16, "auto"
+    if torch.backends.mps.is_available():
+        # Apple Silicon (MPS) prefers float16/bfloat16 depending on the model; float16 is usually OK.
+        return torch.float16, {"": "mps"}
+    return torch.float32, "cpu"  # CPU-safe
+
 @lru_cache(maxsize=1)
 def load_model():
-    """Load and cache the model"""
-    global global_model, global_tokenizer
-
-    if global_model is not None and global_tokenizer is not None:
-        return global_model, global_tokenizer
-
-    try:
-        token = os.getenv("HUGGING_FACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
-        if not token:
-            raise ValueError("No token found. Please set HUGGING_FACE_HUB_TOKEN or HF_TOKEN in Space secrets.")
-
-        login(token=token, add_to_git_credential=False)
-
-        model_id = "CohereLabs/c4ai-command-a-03-2025"
-
-        # Load tokenizer with optimizations
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_id,
-            token=token,
-            use_fast=True,
-            model_max_length=2048
-        )
-
-        # Load model with optimizations
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            token=token,
-            device_map="auto",
-            low_cpu_mem_usage=True,
-            torch_dtype=torch.float16,  # Use float16 for faster inference
-            offload_folder="offload"  # Enable model offloading if needed
-        )
-
-        global_model = model
-        global_tokenizer = tokenizer
-        return model, tokenizer
-
-    except Exception as e:
-        raise Exception(f"Error loading model: {str(e)}")
-
-def generate_with_timeout(model, input_ids, max_new_tokens=100, timeout=60):
-    """Generate response with timeout"""
-    try:
-        with torch.no_grad():
-            output = model.generate(
-                input_ids,
-                max_new_tokens=max_new_tokens,
-                do_sample=True,
-                temperature=0.3,
-                pad_token_id=model.config.eos_token_id,
-                attention_mask=input_ids.new_ones(input_ids.shape),
-                top_p=0.9,
-                repetition_penalty=1.2,
-                timeout_seconds=timeout
-            )
-        return output
-    except Exception as e:
-        raise Exception(f"Generation timeout or error: {str(e)}")
-
-@retry(stop=stop_after_attempt(2), wait=wait_exponential(multiplier=1, min=2, max=4))
-def chat(message, history):
-    start_time = time.time()
-
+    if HF_TOKEN:
+        # In Spaces this isn't strictly necessary if the secret is set, but it doesn't hurt.
+        login(token=HF_TOKEN, add_to_git_credential=False)
+
+    dtype, device_map = _pick_dtype_and_map()
+
+    tok = AutoTokenizer.from_pretrained(
+        MODEL_ID,
+        token=HF_TOKEN,
+        use_fast=True,
+        model_max_length=4096,
+        padding_side="left",  # safer for some chat templates
+    )
+
+    mdl = AutoModelForCausalLM.from_pretrained(
+        MODEL_ID,
+        token=HF_TOKEN,
+        device_map=device_map,
+        low_cpu_mem_usage=True,
+        torch_dtype=dtype,
+    )
+
+    # Fallback for models without an EOS defined
+    if mdl.config.eos_token_id is None and tok.eos_token_id is not None:
+        mdl.config.eos_token_id = tok.eos_token_id
+
+    return mdl, tok
+
+def build_inputs(tokenizer, message, history):
+    # Convert Gradio's (message, history) into a chat template.
+    # With type="messages", history is a list of {"role": ..., "content": ...} dicts.
+    msgs = [{"role": m["role"], "content": m["content"]} for m in (history or [])]
+    msgs.append({"role": "user", "content": message})
+    inputs = tokenizer.apply_chat_template(
+        msgs,
+        tokenize=True,
+        add_generation_prompt=True,
+        return_tensors="pt",
+    )
+    return inputs
+
+def generate_reply(model, tokenizer, input_ids, max_new_tokens=256):
+    input_ids = input_ids.to(model.device)
+    with torch.no_grad():
+        out = model.generate(
+            input_ids=input_ids,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            temperature=0.3,
+            top_p=0.9,
+            repetition_penalty=1.2,
+            pad_token_id=tokenizer.eos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+        )
+    # Slice off the prompt so we only return new tokens
+    gen_only = out[0, input_ids.shape[-1]:]
+    text = tokenizer.decode(gen_only, skip_special_tokens=True)
+    return text.strip()
+
+def chat_fn(message, history):
+    start = time.time()
     try:
-        # Load or get cached model
         model, tokenizer = load_model()
-
-        if history is None:
-            history = []
-
-        # Format messages
-        messages = [{"role": "user", "content": message}]
-        input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
-
-        # Generate response with timeout
-        gen_tokens = generate_with_timeout(model, input_ids)
-
-        # Decode response
-        gen_text = tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
-
-        # Calculate processing time
-        processing_time = time.time() - start_time
-        system_info = format_system_info(processing_time)
-
-        # Format response
-        history.append({"role": "user", "content": message})
-        history.append({"role": "assistant", "content": f"{system_info}{gen_text}"})
-        return history
-
+        inputs = build_inputs(tokenizer, message, history)
+        reply = generate_reply(model, tokenizer, inputs, max_new_tokens=300)
+        # Prepend system info once per turn
+        reply = f"{format_system_info(time.time() - start)}{reply}"
+        return reply
     except Exception as e:
-        processing_time = time.time() - start_time
-        system_info = format_system_info(processing_time)
-        error_msg = f"{system_info}Error during chat: {str(e)}\nAttempting reconnection..."
-
-        if history is None:
-            history = []
-        history.append({"role": "user", "content": message})
-        history.append({"role": "assistant", "content": error_msg})
-        return history
+        return f"{format_system_info(time.time() - start)}Error during chat: {e}"
 
 def check_connection():
     try:
-        token = os.getenv("HUGGING_FACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
-        api = HfApi(token=token)
-        model_info = api.model_info("CohereLabs/c4ai-command-a-03-2025")
-        return f"""
-        {format_system_info()}
-        Connection Status: ✅ Connected
-        Model: {model_info.modelId}
-        Last Modified: {model_info.lastModified}
-        Model Status: {'Loaded' if global_model is not None else 'Not Loaded'}
-        """
+        api = HfApi(token=HF_TOKEN)
+        mi = api.model_info(MODEL_ID)
+        return (
+            f"{format_system_info()}"
+            f"Connection Status: ✅ Connected\n"
+            f"Model: {mi.modelId}\n"
+            f"Last Modified: {mi.lastModified}\n"
+        )
     except Exception as e:
-        return f"{format_system_info()}Connection Status: ❌ Error\nDetails: {str(e)}"
+        return f"{format_system_info()}Connection Status: ❌ Error\nDetails: {e}"
 
-# Create the Gradio interface with loading indicator
 with gr.Blocks(theme=gr.themes.Default()) as demo:
     gr.Markdown(f"# Medical Decision Support AI\n{format_system_info()}")
-
-    with gr.Row():
-        connection_btn = gr.Button("Check Connection Status")
-        connection_status = gr.Textbox(label="Connection Status", lines=6)
-
-    # Add loading configuration
     with gr.Row():
-        gr.Markdown("⚙️ Model is loading... Please wait for first response.")
-
-    chat_interface = gr.ChatInterface(
-        fn=chat,
-        description=f"A medical decision support system that provides healthcare-related information and guidance.\n{format_system_info()}",
+        btn = gr.Button("Check Connection Status")
+        status = gr.Textbox(label="Connection Status", lines=6, value="Click to check…")
+    gr.Markdown("⚙️ Model is loading on first request. Please wait for the first answer.")
+
+    chat = gr.ChatInterface(
+        fn=chat_fn,
+        type="messages",  # use the modern message format
+        description="A medical decision support system that provides healthcare-related information and guidance.",
        examples=[
             "What are the symptoms of hypertension?",
             "What are common drug interactions with aspirin?",
             "What are the warning signs of diabetes?",
         ],
-        # Buttons below are not valid in Gradio 4.x+:
-        # retry_btn="Retry ↺",
-        # undo_btn="Undo ↶",
-        # clear_btn="Clear 🗑️"
-        # type='messages'
-        # To customize buttons, see: https://www.gradio.app/docs/chatinterface/
     )
-
-    connection_btn.click(check_connection, outputs=connection_status)
-
-    # Check connection and load model on startup
-    connection_status.value = check_connection()
-    # Pre-load the model
-    try:
-        load_model()
-    except Exception as e:
-        gr.Warning(f"Model pre-loading failed: {str(e)}")
-
-# Update requirements
-requirements = """
-gradio>=3.50.2
-transformers
-torch
-accelerate
-huggingface_hub
-requests
-tenacity
-"""
-
-demo.launch()
+
+    btn.click(fn=check_connection, outputs=status)
+
+if __name__ == "__main__":
+    demo.launch()
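
A quick way to exercise the new code path outside the Space is a short smoke test (a hypothetical sketch, not part of this commit). It assumes app.py is importable from the working directory, that HF_TOKEN is set with access to MODEL_ID, and that the machine has enough memory to load the model; the history below mirrors the type="messages" format that chat_fn now receives from Gradio.

# smoke_test.py — hypothetical local check of the updated chat_fn
from app import chat_fn

history = [
    {"role": "user", "content": "What are the symptoms of hypertension?"},
    {"role": "assistant", "content": "Hypertension is often asymptomatic."},
]

# Should print the system-info header followed by the generated answer.
print(chat_fn("What are common drug interactions with aspirin?", history))

Because demo.launch() is now guarded by if __name__ == "__main__", importing app builds the Blocks UI without starting a server, so the test only pays the one-time load_model() cost.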