Rajan Sharma committed on
Commit
d877f27
·
verified ·
1 Parent(s): b86ef2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -36
app.py CHANGED
@@ -7,86 +7,123 @@ from huggingface_hub.utils import RepositoryNotFoundError, HfHubHTTPError
7
  import time
8
  import requests
9
  from tenacity import retry, stop_after_attempt, wait_exponential
 
 
 
 
 
 
10
 
11
  def get_timestamp():
12
  """Get current UTC datetime in specified format"""
13
  return datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%SS')
14
 
15
- def format_system_info():
16
  """Format system information header"""
17
- return (
18
  f"Current Date and Time (UTC - YYYY-MM-DD HH:MM:SS formatted): {get_timestamp()}\n"
19
  f"Current User's Login: Raj-VedAI\n"
20
  )
 
 
 
21
 
22
- # Add retry decorator for connection attempts
23
- @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
24
- def initialize_model():
 
 
 
 
 
25
  try:
26
- # Try HUGGING_FACE_HUB_TOKEN first, fallback to HF_TOKEN
27
  token = os.getenv("HUGGING_FACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
28
  if not token:
29
- return False, "No token found. Please set HUGGING_FACE_HUB_TOKEN or HF_TOKEN in Space secrets.", None
30
 
31
- # Force re-login to refresh connection
32
  login(token=token, add_to_git_credential=False)
33
 
34
- # Initialize with device mapping and low memory settings
35
  model_id = "CohereLabs/c4ai-command-a-03-2025"
 
 
36
  tokenizer = AutoTokenizer.from_pretrained(
37
  model_id,
38
  token=token,
39
- use_fast=True
 
40
  )
 
 
41
  model = AutoModelForCausalLM.from_pretrained(
42
  model_id,
43
  token=token,
44
  device_map="auto",
45
  low_cpu_mem_usage=True,
46
- torch_dtype="auto"
 
47
  )
48
- return True, model, tokenizer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  except Exception as e:
50
- return False, f"Error during initialization: {str(e)}", None
51
 
52
- @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
53
  def chat(message, history):
54
- system_info = format_system_info()
55
 
56
  try:
57
- # Initialize model if not already done
58
- success, result, tokenizer = initialize_model()
59
- if not success:
60
- return [{"role": "user", "content": message},
61
- {"role": "assistant", "content": f"{system_info}Error: {result}"}]
62
- model = result
63
 
64
  if history is None:
65
  history = []
66
 
67
- # Format messages with the chat template
68
  messages = [{"role": "user", "content": message}]
69
- input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)
70
 
71
- # Generate response with safety settings
72
- gen_tokens = model.generate(
73
- input_ids,
74
- max_new_tokens=100,
75
- do_sample=True,
76
- temperature=0.3,
77
- pad_token_id=tokenizer.eos_token_id,
78
- attention_mask=input_ids.new_ones(input_ids.shape)
79
- )
80
 
81
  # Decode response
82
  gen_text = tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
83
 
84
- # Format response using new message format
 
 
 
 
85
  history.append({"role": "user", "content": message})
86
  history.append({"role": "assistant", "content": f"{system_info}{gen_text}"})
87
  return history
 
88
  except Exception as e:
 
 
89
  error_msg = f"{system_info}Error during chat: {str(e)}\nAttempting reconnection..."
 
90
  if history is None:
91
  history = []
92
  history.append({"role": "user", "content": message})
@@ -103,11 +140,12 @@ def check_connection():
103
  Connection Status: ✅ Connected
104
  Model: {model_info.modelId}
105
  Last Modified: {model_info.lastModified}
 
106
  """
107
  except Exception as e:
108
  return f"{format_system_info()}Connection Status: ❌ Error\nDetails: {str(e)}"
109
 
110
- # Create the Gradio interface with connection monitoring
111
  with gr.Blocks(theme=gr.themes.Default()) as demo:
112
  gr.Markdown(f"# Medical Decision Support AI\n{format_system_info()}")
113
 
@@ -115,6 +153,10 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
115
  connection_btn = gr.Button("Check Connection Status")
116
  connection_status = gr.Textbox(label="Connection Status", lines=6)
117
 
 
 
 
 
118
  chat_interface = gr.ChatInterface(
119
  fn=chat,
120
  description=f"A medical decision support system that provides healthcare-related information and guidance.\n{format_system_info()}",
@@ -123,12 +165,31 @@ with gr.Blocks(theme=gr.themes.Default()) as demo:
123
  "What are common drug interactions with aspirin?",
124
  "What are the warning signs of diabetes?",
125
  ],
126
- type='messages' # Using new message format
 
 
 
127
  )
128
 
129
  connection_btn.click(check_connection, outputs=connection_status)
130
 
131
- # Check connection on startup
132
  connection_status.value = check_connection()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  demo.launch()
 
7
  import time
8
  import requests
9
  from tenacity import retry, stop_after_attempt, wait_exponential
10
+ from functools import lru_cache
11
+ import torch
12
+
13
+ # Global variables for model caching
14
+ global_model = None
15
+ global_tokenizer = None
16
 
def get_timestamp():
    """Return the current UTC time formatted as ``YYYY-MM-DD HH:MM:SS``.

    Returns:
        str: A 19-character timestamp such as ``2025-01-31 08:15:42``.
    """
    # ``%S`` is the two-digit seconds field; a second literal "S" after it
    # would be copied verbatim into the output ("...:42S").
    return datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
20
 
def format_system_info(processing_time=None):
    """Build the system-information header shown above responses.

    Args:
        processing_time: Elapsed seconds to report, or ``None`` to leave the
            processing-time line out of the header.

    Returns:
        str: Newline-terminated header lines: the current timestamp, the
        user login, and (only when ``processing_time`` is given) the
        processing time rendered with two decimals.
    """
    header_lines = [
        f"Current Date and Time (UTC - YYYY-MM-DD HH:MM:SS formatted): {get_timestamp()}\n",
        "Current User's Login: Raj-VedAI\n",
    ]
    if processing_time is not None:
        header_lines.append(f"Processing Time: {processing_time:.2f} seconds\n")
    return "".join(header_lines)
30
 
@lru_cache(maxsize=1)
def load_model():
    """Load the tokenizer and model once and hand back the cached pair.

    The result is stored in the module globals ``global_model`` and
    ``global_tokenizer`` so that other code in this file can see whether a
    model is loaded. A failed load stores nothing, so a later call tries
    again.

    Returns:
        tuple: ``(model, tokenizer)``.

    Raises:
        Exception: If no access token is configured or any loading step
            fails. The message is prefixed with ``"Error loading model: "``
            and the original exception is attached as ``__cause__``.
    """
    global global_model, global_tokenizer

    # Already loaded: reuse the objects published by an earlier call.
    if global_model is not None and global_tokenizer is not None:
        return global_model, global_tokenizer

    try:
        # Either environment variable may carry the Hub access token.
        token = os.getenv("HUGGING_FACE_HUB_TOKEN") or os.getenv("HF_TOKEN")
        if not token:
            raise ValueError("No token found. Please set HUGGING_FACE_HUB_TOKEN or HF_TOKEN in Space secrets.")

        login(token=token, add_to_git_credential=False)

        model_id = "CohereLabs/c4ai-command-a-03-2025"

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            token=token,
            use_fast=True,
            model_max_length=2048,
        )

        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=token,
            device_map="auto",
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,  # half precision, chosen for faster inference
            offload_folder="offload",  # directory used if weights need to be offloaded
        )
    except Exception as e:
        # Keep the historical exception type and message for callers, but
        # chain the cause so the real traceback is not lost.
        raise Exception(f"Error loading model: {str(e)}") from e

    # Publish only after both objects loaded successfully.
    global_model = model
    global_tokenizer = tokenizer
    return model, tokenizer
72
+
def generate_with_timeout(model, input_ids, max_new_tokens=100, timeout=60):
    """Run sampling-based generation with a wall-clock budget.

    Args:
        model: Object exposing ``generate(...)`` and ``config.eos_token_id``.
        input_ids: Prompt token ids; must support ``new_ones`` and ``shape``
            (a ``torch`` tensor does).
        max_new_tokens: Upper bound on the number of generated tokens.
        timeout: Time budget in seconds, forwarded as ``max_time``.
            NOTE(review): per the transformers documentation this is a soft
            limit, generation finishes the pass in progress before stopping.

    Returns:
        Whatever ``model.generate`` returns.

    Raises:
        Exception: If generation fails. The message is prefixed with
            ``"Generation timeout or error: "`` and the original exception
            is attached as ``__cause__``.
    """
    try:
        # Inference only: gradients are not needed.
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.3,
                pad_token_id=model.config.eos_token_id,
                # Single unpadded prompt, so every position is attended to.
                attention_mask=input_ids.new_ones(input_ids.shape),
                top_p=0.9,
                repetition_penalty=1.2,
                # ``max_time`` is the generation time budget; an unrecognised
                # keyword such as ``timeout_seconds`` is rejected by generate().
                max_time=timeout,
            )
        return output
    except Exception as e:
        raise Exception(f"Generation timeout or error: {str(e)}") from e
91
 
92
+ @retry(stop=stop_after_attempt(2), wait=wait_exponential(multiplier=1, min=2, max=4))
93
  def chat(message, history):
94
+ start_time = time.time()
95
 
96
  try:
97
+ # Load or get cached model
98
+ model, tokenizer = load_model()
 
 
 
 
99
 
100
  if history is None:
101
  history = []
102
 
103
+ # Format messages
104
  messages = [{"role": "user", "content": message}]
105
+ input_ids = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
106
 
107
+ # Generate response with timeout
108
+ gen_tokens = generate_with_timeout(model, input_ids)
 
 
 
 
 
 
 
109
 
110
  # Decode response
111
  gen_text = tokenizer.decode(gen_tokens[0], skip_special_tokens=True)
112
 
113
+ # Calculate processing time
114
+ processing_time = time.time() - start_time
115
+ system_info = format_system_info(processing_time)
116
+
117
+ # Format response
118
  history.append({"role": "user", "content": message})
119
  history.append({"role": "assistant", "content": f"{system_info}{gen_text}"})
120
  return history
121
+
122
  except Exception as e:
123
+ processing_time = time.time() - start_time
124
+ system_info = format_system_info(processing_time)
125
  error_msg = f"{system_info}Error during chat: {str(e)}\nAttempting reconnection..."
126
+
127
  if history is None:
128
  history = []
129
  history.append({"role": "user", "content": message})
 
140
  Connection Status: ✅ Connected
141
  Model: {model_info.modelId}
142
  Last Modified: {model_info.lastModified}
143
+ Model Status: {'Loaded' if global_model is not None else 'Not Loaded'}
144
  """
145
  except Exception as e:
146
  return f"{format_system_info()}Connection Status: ❌ Error\nDetails: {str(e)}"
147
 
148
+ # Create the Gradio interface with loading indicator
149
  with gr.Blocks(theme=gr.themes.Default()) as demo:
150
  gr.Markdown(f"# Medical Decision Support AI\n{format_system_info()}")
151
 
 
153
  connection_btn = gr.Button("Check Connection Status")
154
  connection_status = gr.Textbox(label="Connection Status", lines=6)
155
 
156
+ # Add loading configuration
157
+ with gr.Row():
158
+ gr.Markdown("⚙️ Model is loading... Please wait for first response.")
159
+
160
  chat_interface = gr.ChatInterface(
161
  fn=chat,
162
  description=f"A medical decision support system that provides healthcare-related information and guidance.\n{format_system_info()}",
 
165
  "What are common drug interactions with aspirin?",
166
  "What are the warning signs of diabetes?",
167
  ],
168
+ type='messages',
169
+ retry_btn="Retry ↺",
170
+ undo_btn="Undo ↶",
171
+ clear_btn="Clear 🗑️"
172
  )
173
 
174
  connection_btn.click(check_connection, outputs=connection_status)
175
 
176
+ # Check connection and load model on startup
177
  connection_status.value = check_connection()
178
+ # Pre-load the model
179
+ try:
180
+ load_model()
181
+ except Exception as e:
182
+ gr.Warning(f"Model pre-loading failed: {str(e)}")
183
+
184
+ # Update requirements
185
+ requirements = """
186
+ gradio>=3.50.2
187
+ transformers
188
+ torch
189
+ accelerate
190
+ huggingface_hub
191
+ requests
192
+ tenacity
193
+ """
194
 
195
  demo.launch()