Sooteemon commited on
Commit
b5400ea
·
verified ·
1 Parent(s): 1798b32

Update sentiment_analyzer.py

Browse files
Files changed (1) hide show
  1. sentiment_analyzer.py +28 -66
sentiment_analyzer.py CHANGED
@@ -2,7 +2,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
2
  import torch
3
  import re
4
 
5
- class NewsAnalyzer: # --- MODIFIED: Renamed class ---
6
  def __init__(self, model_name="google/gemma-2-2b-it"):
7
  """
8
  Initialize news analyzer with Gemma model
@@ -28,38 +28,26 @@ class NewsAnalyzer: # --- MODIFIED: Renamed class ---
28
 
29
  except Exception as e:
30
  print(f"Error loading model: {e}")
31
- # Fallback to sentiment pipeline
32
  self.model = None
33
  self.sentiment_pipeline = pipeline(
34
  "sentiment-analysis",
35
  model="distilbert-base-uncased-finetuned-sst-2-english"
36
  )
37
 
38
- def analyze_news_item(self, text): # --- MODIFIED: Renamed function ---
39
  """
40
  วิเคราะห์ข่าว (Sentiment, Theme, Impact)
41
-
42
- Args:
43
- text: ข้อความที่ต้องการวิเคราะห์
44
-
45
- Returns:
46
- dict: {sentiment, score, theme, impact, explanation}
47
  """
48
  if not text or len(text.strip()) == 0:
49
  return {
50
- "sentiment": "Neutral",
51
- "score": 0.5,
52
- "theme": "Other",
53
- "impact": "Neutral",
54
- "explanation": "No text to analyze"
55
  }
56
 
57
- # ถ้า model โหลดไม่สำเร็จ ใช้ fallback pipeline
58
  if self.model is None:
59
  return self._fallback_sentiment(text)
60
 
61
  try:
62
- # --- MODIFIED: New comprehensive prompt ---
63
  prompt = f"""Analyze this financial news article. Provide your analysis in the *exact* format specified below.
64
 
65
  **Categories to use:**
@@ -77,128 +65,102 @@ Theme: [Selected Theme]
77
  Impact: [Selected Impact]
78
  Reason: [Brief explanation of your analysis]"""
79
 
80
- # Tokenize และ generate
81
  inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
82
  inputs = inputs.to(self.device)
 
 
 
83
 
84
  with torch.no_grad():
85
  outputs = self.model.generate(
86
  **inputs,
87
- max_new_tokens=200, # Increased tokens for longer response
88
  temperature=0.3,
89
  do_sample=True,
90
  pad_token_id=self.tokenizer.eos_token_id
91
  )
92
 
93
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
94
 
95
- # Parse response
96
- return self._parse_llm_analysis(response) # --- MODIFIED ---
97
 
98
  except Exception as e:
99
  print(f"Error in analysis: {e}")
100
  return self._fallback_sentiment(text)
101
 
102
- def _parse_llm_analysis(self, response): # --- MODIFIED: Renamed and updated parser ---
103
  """แยก sentiment, score, theme, impact และ explanation จาก LLM response"""
104
  sentiment = "Neutral"
105
  score = 0.5
106
- theme = "Other" # --- ADDED ---
107
- impact = "Neutral" # --- ADDED ---
108
- explanation = "Unable to analyze"
109
 
110
  try:
111
- # Extract sentiment
112
  sentiment_line = re.search(r'Sentiment:\s*(\w+)', response, re.IGNORECASE)
113
  if sentiment_line:
114
  sentiment = sentiment_line.group(1).capitalize()
115
 
116
- # Extract score
117
  score_line = re.search(r'Score:\s*([\d.]+)', response)
118
  if score_line:
119
  score = float(score_line.group(1))
120
- score = max(0.0, min(1.0, score)) # Clamp between 0-1
121
 
122
- # --- ADDED: Extract Theme ---
123
  theme_line = re.search(r'Theme:\s*([\w\/ -]+)', response, re.IGNORECASE)
124
  if theme_line:
125
  theme = theme_line.group(1).strip()
126
 
127
- # --- ADDED: Extract Impact ---
128
  impact_line = re.search(r'Impact:\s*(\w+)', response, re.IGNORECASE)
129
  if impact_line:
130
  impact = impact_line.group(1).capitalize().strip()
131
 
132
- # Extract reason/explanation
133
- reason_match = re.search(r'Reason:\s*(.+?)(?:\n|$)', response, re.DOTALL | re.IGNORECASE)
134
  if reason_match:
135
  explanation = reason_match.group(1).strip()
 
136
 
137
- # Validate sentiment
138
  if sentiment not in ["Positive", "Negative", "Neutral"]:
139
  sentiment = "Neutral"
140
-
141
- # Validate impact
142
  if impact not in ["Opportunity", "Risk", "Neutral"]:
143
  impact = "Neutral"
144
 
145
  except Exception as e:
146
- print(f"Parse error: {e}")
147
 
148
  return {
149
- "sentiment": sentiment,
150
- "score": score,
151
- "theme": theme,
152
- "impact": impact,
153
- "explanation": explanation
154
  }
155
 
156
  def _fallback_sentiment(self, text):
157
  """Fallback method ใช้ DistilBERT"""
158
  try:
159
  result = self.sentiment_pipeline(text[:512])[0]
160
-
161
- # Convert to our format
162
  sentiment = "Positive" if result['label'] == 'POSITIVE' else "Negative"
163
  score = result['score']
164
-
165
  return {
166
- "sentiment": sentiment,
167
- "score": score,
168
- "theme": "N/A", # --- ADDED ---
169
- "impact": "N/A", # --- ADDED ---
170
- "explanation": f"Analyzed using fallback model with {score:.2%} confidence"
171
  }
172
  except:
173
  return {
174
- "sentiment": "Neutral",
175
- "score": 0.5,
176
- "theme": "N/A", # --- ADDED ---
177
- "impact": "N/A", # --- ADDED ---
178
- "explanation": "Analysis unavailable"
179
  }
180
 
181
  def analyze_batch(self, news_list):
182
  """
183
  วิเคราะห์ sentiment หลายข่าวพร้อมกัน
184
-
185
- Args:
186
- news_list: list ของ dict ที่มี title และ summary
187
-
188
- Returns:
189
- list: รายการผลการวิเคราะห์
190
  """
191
  results = []
192
-
193
  for news in news_list:
194
- # รวม title และ summary
195
  combined_text = f"{news.get('title', '')} {news.get('summary', '')}"
196
-
197
- sentiment_result = self.analyze_news_item(combined_text) # --- MODIFIED ---
198
-
199
  results.append({
200
  **news,
201
  **sentiment_result
202
  })
203
-
204
  return results
 
2
  import torch
3
  import re
4
 
5
+ class NewsAnalyzer:
6
  def __init__(self, model_name="google/gemma-2-2b-it"):
7
  """
8
  Initialize news analyzer with Gemma model
 
28
 
29
  except Exception as e:
30
  print(f"Error loading model: {e}")
 
31
  self.model = None
32
  self.sentiment_pipeline = pipeline(
33
  "sentiment-analysis",
34
  model="distilbert-base-uncased-finetuned-sst-2-english"
35
  )
36
 
37
+ def analyze_news_item(self, text):
38
  """
39
  วิเคราะห์ข่าว (Sentiment, Theme, Impact)
 
 
 
 
 
 
40
  """
41
  if not text or len(text.strip()) == 0:
42
  return {
43
+ "sentiment": "Neutral", "score": 0.5, "theme": "Other",
44
+ "impact": "Neutral", "explanation": "No text to analyze"
 
 
 
45
  }
46
 
 
47
  if self.model is None:
48
  return self._fallback_sentiment(text)
49
 
50
  try:
 
51
  prompt = f"""Analyze this financial news article. Provide your analysis in the *exact* format specified below.
52
 
53
  **Categories to use:**
 
65
  Impact: [Selected Impact]
66
  Reason: [Brief explanation of your analysis]"""
67
 
 
68
  inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)
69
  inputs = inputs.to(self.device)
70
+
71
+ # --- MODIFIED: Get prompt length to slice output correctly ---
72
+ prompt_length = inputs['input_ids'].shape[1]
73
 
74
  with torch.no_grad():
75
  outputs = self.model.generate(
76
  **inputs,
77
+ max_new_tokens=200,
78
  temperature=0.3,
79
  do_sample=True,
80
  pad_token_id=self.tokenizer.eos_token_id
81
  )
82
 
83
+ # --- MODIFIED: Decode *only* the new tokens, not the prompt ---
84
+ new_tokens = outputs[0][prompt_length:]
85
+ response = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
86
 
87
+ return self._parse_llm_analysis(response)
 
88
 
89
  except Exception as e:
90
  print(f"Error in analysis: {e}")
91
  return self._fallback_sentiment(text)
92
 
93
+ def _parse_llm_analysis(self, response):
94
  """แยก sentiment, score, theme, impact และ explanation จาก LLM response"""
95
  sentiment = "Neutral"
96
  score = 0.5
97
+ theme = "Other"
98
+ impact = "Neutral"
99
+ explanation = "Unable to parse" # Default explanation if parse fails
100
 
101
  try:
 
102
  sentiment_line = re.search(r'Sentiment:\s*(\w+)', response, re.IGNORECASE)
103
  if sentiment_line:
104
  sentiment = sentiment_line.group(1).capitalize()
105
 
 
106
  score_line = re.search(r'Score:\s*([\d.]+)', response)
107
  if score_line:
108
  score = float(score_line.group(1))
109
+ score = max(0.0, min(1.0, score))
110
 
 
111
  theme_line = re.search(r'Theme:\s*([\w\/ -]+)', response, re.IGNORECASE)
112
  if theme_line:
113
  theme = theme_line.group(1).strip()
114
 
 
115
  impact_line = re.search(r'Impact:\s*(\w+)', response, re.IGNORECASE)
116
  if impact_line:
117
  impact = impact_line.group(1).capitalize().strip()
118
 
119
+ # --- MODIFIED: More robust regex for Reason (captures multi-line) ---
120
+ reason_match = re.search(r'Reason:\s*(.*)', response, re.DOTALL | re.IGNORECASE)
121
  if reason_match:
122
  explanation = reason_match.group(1).strip()
123
+ # If parsing fails, explanation will remain "Unable to parse" or the last good value
124
 
 
125
  if sentiment not in ["Positive", "Negative", "Neutral"]:
126
  sentiment = "Neutral"
 
 
127
  if impact not in ["Opportunity", "Risk", "Neutral"]:
128
  impact = "Neutral"
129
 
130
  except Exception as e:
131
+ print(f"Parse error: {e}. Response was: {response}")
132
 
133
  return {
134
+ "sentiment": sentiment, "score": score, "theme": theme,
135
+ "impact": impact, "explanation": explanation
 
 
 
136
  }
137
 
138
  def _fallback_sentiment(self, text):
139
  """Fallback method ใช้ DistilBERT"""
140
  try:
141
  result = self.sentiment_pipeline(text[:512])[0]
 
 
142
  sentiment = "Positive" if result['label'] == 'POSITIVE' else "Negative"
143
  score = result['score']
 
144
  return {
145
+ "sentiment": sentiment, "score": score, "theme": "N/A",
146
+ "impact": "N/A", "explanation": f"Analyzed using fallback model"
 
 
 
147
  }
148
  except:
149
  return {
150
+ "sentiment": "Neutral", "score": 0.5, "theme": "N/A",
151
+ "impact": "N/A", "explanation": "Analysis unavailable"
 
 
 
152
  }
153
 
154
  def analyze_batch(self, news_list):
155
  """
156
  วิเคราะห์ sentiment หลายข่าวพร้อมกัน
 
 
 
 
 
 
157
  """
158
  results = []
 
159
  for news in news_list:
 
160
  combined_text = f"{news.get('title', '')} {news.get('summary', '')}"
161
+ sentiment_result = self.analyze_news_item(combined_text)
 
 
162
  results.append({
163
  **news,
164
  **sentiment_result
165
  })
 
166
  return results