s880453 commited on
Commit
3269285
·
verified ·
1 Parent(s): 2d71652

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -12
app.py CHANGED
@@ -121,26 +121,35 @@ def create_plot(df, chart_type, x_column, y_column, group_column=None, size_colu
121
 
122
  elif chart_type == "堆疊長條圖":
123
  if group_column and group_column in df.columns:
124
- grouped_df = df.pivot_table(index=x_column, columns=group_column,
125
- values=y_column, aggfunc=agg_func).reset_index()
126
- grouped_df = grouped_df.fillna(0)
127
 
128
- # 取得所有類別
129
- categories = grouped_df.columns.tolist()
 
 
 
 
 
 
 
 
130
  categories.remove(x_column)
131
 
 
132
  for i, category in enumerate(categories):
133
  color = colors[i % len(colors)]
134
- if category in custom_colors:
135
- color = custom_colors[category]
136
 
137
  pattern_shape = None
138
  if patterns and i < len(patterns) and patterns[i] != "無":
139
  pattern_shape = patterns[i]
140
 
141
  fig.add_trace(go.Bar(
142
- x=grouped_df[x_column],
143
- y=grouped_df[category],
144
  name=str(category),
145
  marker_color=color,
146
  marker_pattern_shape=pattern_shape
@@ -148,9 +157,10 @@ def create_plot(df, chart_type, x_column, y_column, group_column=None, size_colu
148
 
149
  fig.update_layout(barmode='stack')
150
  else:
151
- grouped_df = df.groupby(x_column)[y_column].agg(agg_func).reset_index()
152
- fig = px.bar(grouped_df, x=x_column, y=y_column, **fig_params)
153
-
 
154
 
155
  elif chart_type == "群組長條圖":
156
  if group_column and group_column in df.columns:
 
121
 
122
  elif chart_type == "堆疊長條圖":
123
  if group_column and group_column in df.columns:
124
+ # 明確將字符串列轉換為類別型
125
+ df[x_column] = df[x_column].astype('category')
126
+ df[group_column] = df[group_column].astype('category')
127
 
128
+ # 先進行計數統計
129
+ count_df = df.groupby([x_column, group_column]).size().reset_index(name='count')
130
+
131
+ # 創建樞紐表
132
+ pivot_df = count_df.pivot_table(index=x_column, columns=group_column,
133
+ values='count', aggfunc='sum').reset_index()
134
+ pivot_df = pivot_df.fillna(0)
135
+
136
+ # 獲取所有情緒類別
137
+ categories = pivot_df.columns.tolist()
138
  categories.remove(x_column)
139
 
140
+ # 創建圖表
141
  for i, category in enumerate(categories):
142
  color = colors[i % len(colors)]
143
+ if str(category) in custom_colors:
144
+ color = custom_colors[str(category)]
145
 
146
  pattern_shape = None
147
  if patterns and i < len(patterns) and patterns[i] != "無":
148
  pattern_shape = patterns[i]
149
 
150
  fig.add_trace(go.Bar(
151
+ x=pivot_df[x_column],
152
+ y=pivot_df[category],
153
  name=str(category),
154
  marker_color=color,
155
  marker_pattern_shape=pattern_shape
 
157
 
158
  fig.update_layout(barmode='stack')
159
  else:
160
+ # 簡單計數
161
+ counts = df[x_column].value_counts().reset_index()
162
+ counts.columns = [x_column, 'count']
163
+ fig = px.bar(counts, x=x_column, y='count', **fig_params)
164
 
165
  elif chart_type == "群組長條圖":
166
  if group_column and group_column in df.columns: