frankai98 commited on
Commit
68286e4
·
verified ·
1 Parent(s): fa9cb53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -18
app.py CHANGED
@@ -8,6 +8,9 @@ from streamlit.components.v1 import html
8
  import pandas as pd
9
  import torch
10
  import random
 
 
 
11
 
12
  # Retrieve the token from environment variables
13
  hf_token = os.environ.get("HF_TOKEN")
@@ -50,12 +53,26 @@ st.header("Review Scorer & Report Generator")
50
  # Concise introduction
51
  st.write("This model will score your reviews in your CSV file and generate a report based on your query and those results.")
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  # Load models with caching to avoid reloading on every run
54
  @st.cache_resource
55
  def load_models():
56
  llama_pipe = None
57
  score_pipe = None
58
- gemma_pipe = None
59
 
60
  try:
61
  st.info("Loading Llama 3.2 summarization model...")
@@ -64,6 +81,14 @@ def load_models():
64
  device=0, # Use GPU if available
65
  torch_dtype=torch.bfloat16,) # Use FP16 for efficiency
66
  st.success("Llama 3.2 summarization model loaded successfully!")
 
 
 
 
 
 
 
 
67
  except Exception as e:
68
  st.error(f"Error loading Llama 3.2 summarization model: {e}")
69
  st.error(f"Detailed error: {type(e).__name__}: {str(e)}")
@@ -74,26 +99,21 @@ def load_models():
74
  model="cardiffnlp/twitter-roberta-base-sentiment-latest",
75
  device=0 if torch.cuda.is_available() else -1)
76
  st.success("Sentiment analysis model loaded successfully!")
 
 
 
 
 
 
 
 
77
  except Exception as e:
78
  st.error(f"Error loading sentiment analysis model: {e}")
79
 
80
- try:
81
- st.info("Loading Gemma model...")
82
- # Load the tokenizer separately with the chat template
83
- tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
84
- gemma_pipe = pipeline("text-generation",
85
- model="google/gemma-3-1b-it",
86
- tokenizer=tokenizer, # Pass the loaded tokenizer here
87
- device=0,
88
- torch_dtype=torch.bfloat16)
89
- st.success("Gemma model loaded successfully!")
90
- except Exception as e:
91
- st.error(f"Error loading Gemma model: {e}")
92
- st.error(f"Detailed error: {type(e).__name__}: {str(e)}")
93
 
94
- return llama_pipe, score_pipe, gemma_pipe
95
 
96
- llama_pipe, score_pipe, gemma_pipe = load_models()
97
 
98
  def extract_assistant_content(raw_response):
99
  """Extract only the assistant's content from the Gemma-3 response."""
@@ -225,12 +245,27 @@ else:
225
  # Stage 2: Generate Report using Gemma in the new messages format.
226
  status_text.markdown("**📝 Generating report with Gemma...**")
227
 
228
- # Clear CUDA cache before using Gemma to free up memory
 
 
229
  torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
230
 
231
  # Sample or summarize the data for Gemma to avoid memory issues
232
  import random
233
- max_reviews = 100 # Adjust based on your GPU memory
234
  if len(scored_docs) > max_reviews:
235
  sampled_docs = random.sample(scored_docs, max_reviews)
236
  st.info(f"Sampling {max_reviews} out of {len(scored_docs)} reviews for report generation")
 
8
  import pandas as pd
9
  import torch
10
  import random
11
+ import gc
12
+ import time
13
+ from threading import Thread
14
 
15
  # Retrieve the token from environment variables
16
  hf_token = os.environ.get("HF_TOKEN")
 
53
  # Concise introduction
54
  st.write("This model will score your reviews in your CSV file and generate a report based on your query and those results.")
55
 
56
+ def display_temp_message(message, message_type="info", duration=5):
57
+ """Display a temporary message that disappears after specified duration."""
58
+ placeholder = st.empty()
59
+
60
+ if message_type == "info":
61
+ placeholder.info(message)
62
+ elif message_type == "success":
63
+ placeholder.success(message)
64
+ elif message_type == "error":
65
+ placeholder.error(message)
66
+
67
+ # Wait and then clear the message
68
+ time.sleep(duration)
69
+ placeholder.empty()
70
+
71
  # Load models with caching to avoid reloading on every run
72
  @st.cache_resource
73
  def load_models():
74
  llama_pipe = None
75
  score_pipe = None
 
76
 
77
  try:
78
  st.info("Loading Llama 3.2 summarization model...")
 
81
  device=0, # Use GPU if available
82
  torch_dtype=torch.bfloat16,) # Use FP16 for efficiency
83
  st.success("Llama 3.2 summarization model loaded successfully!")
84
+
85
+ # Display success message that will disappear after 5 seconds
86
+ Thread(
87
+ target=display_temp_message,
88
+ args=("Llama 3.2 summarization model loaded successfully!", "success"),
89
+ daemon=True
90
+ ).start()
91
+
92
  except Exception as e:
93
  st.error(f"Error loading Llama 3.2 summarization model: {e}")
94
  st.error(f"Detailed error: {type(e).__name__}: {str(e)}")
 
99
  model="cardiffnlp/twitter-roberta-base-sentiment-latest",
100
  device=0 if torch.cuda.is_available() else -1)
101
  st.success("Sentiment analysis model loaded successfully!")
102
+
103
+ # Display success message that will disappear after 5 seconds
104
+ Thread(
105
+ target=display_temp_message,
106
+ args=("Sentiment analysis model loaded successfully!", "success"),
107
+ daemon=True
108
+ ).start()
109
+
110
  except Exception as e:
111
  st.error(f"Error loading sentiment analysis model: {e}")
112
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
+ return llama_pipe, score_pipe
115
 
116
+ llama_pipe, score_pipe = load_models()
117
 
118
  def extract_assistant_content(raw_response):
119
  """Extract only the assistant's content from the Gemma-3 response."""
 
245
  # Stage 2: Generate Report using Gemma in the new messages format.
246
  status_text.markdown("**📝 Generating report with Gemma...**")
247
 
248
+ # After using score_pipe
249
+ del score_pipe
250
+ gc.collect()
251
  torch.cuda.empty_cache()
252
+
253
+ # After using llama_pipe (the Llama 3.2 summarization model)
254
+ del llama_pipe
255
+ gc.collect()
256
+ torch.cuda.empty_cache()
257
+
258
+ # Then reload Gemma specifically for the final step
259
+ tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
260
+ gemma_pipe = pipeline("text-generation",
261
+ model="google/gemma-3-1b-it",
262
+ tokenizer=tokenizer,
263
+ device=0,
264
+ torch_dtype=torch.bfloat16)
265
 
266
  # Sample or summarize the data for Gemma to avoid memory issues
267
  import random
268
+ max_reviews = 50 # Adjust based on your GPU memory
269
  if len(scored_docs) > max_reviews:
270
  sampled_docs = random.sample(scored_docs, max_reviews)
271
  st.info(f"Sampling {max_reviews} out of {len(scored_docs)} reviews for report generation")