nsschw committed on
Commit
02b5cfb
·
1 Parent(s): 8cc7f61

Refactor analysis functions and add default prompt for statement analysis

Browse files
Files changed (2) hide show
  1. app.py +42 -129
  2. default-prompt.txt +9 -0
app.py CHANGED
@@ -22,118 +22,54 @@ for i in range(1, 6):
22
  if tokens:
23
  likert_tokens[i] = tokens[0]
24
 
25
- print(f"Likert tokens: {likert_tokens}")
26
 
27
- def analyze_statement(statement):
28
- """
29
- Analyze a statement and return Likert probabilities using chat template
30
- """
31
- try:
32
- # Create chat messages
33
- messages = [
34
- {
35
- "role": "user",
36
- "content": f"""How would you respond to the following statement:
37
- '{statement}'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- Respond only with a number from 1 to 5, where:
40
- 1 = Strongly Disagree
41
- 2 = Disagree
42
- 3 = Neutral
43
- 4 = Agree
44
- 5 = Strongly Agree"""
45
- }
46
- ]
47
-
48
- # Apply chat template
49
- prompt = tokenizer.apply_chat_template(
50
- messages,
51
- tokenize=False,
52
- add_generation_prompt=True
53
- )
54
-
55
- # Tokenize
56
- inputs = tokenizer(prompt, return_tensors="pt")
57
-
58
- # Generate with output scores
59
- with torch.no_grad():
60
- outputs = model.generate(
61
- **inputs,
62
- max_new_tokens=1,
63
- return_dict_in_generate=True,
64
- output_scores=True,
65
- do_sample=False,
66
- pad_token_id=tokenizer.eos_token_id
67
- )
68
-
69
- # Get probabilities for first generated token
70
- if outputs.scores:
71
- logits = outputs.scores[0][0] # First token, first batch
72
- probs = torch.softmax(logits, dim=-1)
73
-
74
- # Extract Likert probabilities
75
- likert_probs = {}
76
- for value, token_id in likert_tokens.items():
77
- likert_probs[value] = probs[token_id].item()
78
-
79
- # Create simple bar chart
80
- fig, ax = plt.subplots(figsize=(8, 5))
81
- values = list(likert_probs.keys())
82
- probabilities = list(likert_probs.values())
83
-
84
- bars = ax.bar(values, probabilities, color='steelblue', alpha=0.8, edgecolor='navy')
85
-
86
- # Add value labels
87
- for bar, prob in zip(bars, probabilities):
88
- height = bar.get_height()
89
- ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
90
- f'{prob:.3f}', ha='center', va='bottom')
91
-
92
- ax.set_xlabel('Likert Scale Value')
93
- ax.set_ylabel('Probability')
94
- ax.set_title('Response Probability Distribution')
95
- ax.set_xticks(values)
96
- ax.set_ylim(0, max(probabilities) * 1.2 if probabilities else 1)
97
- ax.grid(True, axis='y', alpha=0.3)
98
-
99
- plt.tight_layout()
100
-
101
- # Format probabilities text
102
- prob_text = "\n".join([f"{k}: {v:.4f}" for k, v in likert_probs.items()])
103
-
104
- # Show what the model actually generated (for debugging)
105
- generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
106
- debug_info = f"Generated: {generated_text[-50:]}" # Last 50 chars
107
-
108
- return fig, prob_text, f"✅ Analysis complete\n{debug_info}"
109
-
110
- else:
111
- return None, "", "❌ No scores generated"
112
-
113
- except Exception as e:
114
- return None, "", f"❌ Error: {str(e)}"
115
 
116
  def analyze_with_persona(statement, persona=""):
117
  """
118
- Analyze with optional persona
119
  """
120
  try:
 
 
 
 
 
 
121
  # Create chat messages with optional system prompt
122
  messages = []
123
  if persona.strip():
124
  messages.append({"role": "system", "content": persona.strip()})
125
 
126
  messages.append({
127
- "role": "user",
128
- "content": f"""How would you respond to the following statement:
129
- '{statement}'
130
-
131
- Respond only with a number from 1 to 5, where:
132
- 1 = Strongly Disagree
133
- 2 = Disagree
134
- 3 = Neutral
135
- 4 = Agree
136
- 5 = Strongly Agree"""
137
  })
138
 
139
  # Apply chat template
@@ -150,7 +86,7 @@ Respond only with a number from 1 to 5, where:
150
  with torch.no_grad():
151
  outputs = model.generate(
152
  **inputs,
153
- max_new_tokens=3,
154
  return_dict_in_generate=True,
155
  output_scores=True,
156
  do_sample=False,
@@ -167,39 +103,16 @@ Respond only with a number from 1 to 5, where:
167
  for value, token_id in likert_tokens.items():
168
  likert_probs[value] = probs[token_id].item()
169
 
170
- # Create simple bar chart
171
- fig, ax = plt.subplots(figsize=(8, 5))
172
- values = list(likert_probs.keys())
173
- probabilities = list(likert_probs.values())
174
-
175
- bars = ax.bar(values, probabilities, color='steelblue', alpha=0.8, edgecolor='navy')
176
-
177
- # Add value labels
178
- for bar, prob in zip(bars, probabilities):
179
- height = bar.get_height()
180
- ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
181
- f'{prob:.3f}', ha='center', va='bottom')
182
-
183
- ax.set_xlabel('Likert Scale Value')
184
- ax.set_ylabel('Probability')
185
- title = 'Response Probability Distribution'
186
- if persona.strip():
187
- title += f'\nPersona: {persona[:50]}...' if len(persona) > 50 else f'\nPersona: {persona}'
188
- ax.set_title(title)
189
- ax.set_xticks(values)
190
- ax.set_ylim(0, max(probabilities) * 1.2 if probabilities else 1)
191
- ax.grid(True, axis='y', alpha=0.3)
192
-
193
- plt.tight_layout()
194
 
195
  # Format probabilities text
196
  prob_text = "\n".join([f"{k}: {v:.4f}" for k, v in likert_probs.items()])
197
 
198
- # Show what the model actually generated
199
- generated_text = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
200
- debug_info = f"Generated: {generated_text[-100:]}" # Last 100 chars
201
-
202
- return fig, prob_text, f"✅ Analysis complete\n{debug_info}"
203
 
204
  else:
205
  return None, "", "❌ No scores generated"
 
22
  if tokens:
23
  likert_tokens[i] = tokens[0]
24
 
 
25
 
26
+ def create_probability_plot(likert_probs, persona=""):
27
+ """Create a bar chart for Likert scale probabilities"""
28
+ fig, ax = plt.subplots(figsize=(8, 5))
29
+ values = list(likert_probs.keys())
30
+ probabilities = list(likert_probs.values())
31
+
32
+ bars = ax.bar(values, probabilities, color='steelblue', alpha=0.8, edgecolor='navy')
33
+
34
+ # Add value labels
35
+ for bar, prob in zip(bars, probabilities):
36
+ height = bar.get_height()
37
+ ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
38
+ f'{prob:.3f}', ha='center', va='bottom')
39
+
40
+ ax.set_xlabel('Likert Scale Value')
41
+ ax.set_ylabel('Probability')
42
+ title = 'Response Probability Distribution'
43
+ if persona.strip():
44
+ title += f'\nPersona: {persona[:50]}...' if len(persona) > 50 else f'\nPersona: {persona}'
45
+ ax.set_title(title)
46
+ ax.set_xticks(values)
47
+ ax.set_ylim(0, max(probabilities) * 1.2 if probabilities else 1)
48
+ ax.grid(True, axis='y', alpha=0.3)
49
+
50
+ plt.tight_layout()
51
+ return fig
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  def analyze_with_persona(statement, persona=""):
55
  """
56
+ Analyze with persona prompt
57
  """
58
  try:
59
+
60
+ # read default prompt
61
+ with open("default_prompt.txt", "r") as f:
62
+ default_prompt = f.read().strip()
63
+
64
+
65
  # Create chat messages with optional system prompt
66
  messages = []
67
  if persona.strip():
68
  messages.append({"role": "system", "content": persona.strip()})
69
 
70
  messages.append({
71
+ "role": "user",
72
+ "content": default_prompt.format(statement="Your actual statement here")
 
 
 
 
 
 
 
 
73
  })
74
 
75
  # Apply chat template
 
86
  with torch.no_grad():
87
  outputs = model.generate(
88
  **inputs,
89
+ max_new_tokens=1,
90
  return_dict_in_generate=True,
91
  output_scores=True,
92
  do_sample=False,
 
103
  for value, token_id in likert_tokens.items():
104
  likert_probs[value] = probs[token_id].item()
105
 
106
+ # Create probability plot
107
+ fig = create_probability_plot(likert_probs, persona)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
  # Format probabilities text
110
  prob_text = "\n".join([f"{k}: {v:.4f}" for k, v in likert_probs.items()])
111
 
112
+ # Show what the model actually generated including input and special tokens
113
+ debug_info = f"Input: {prompt}...\n\n"
114
+ debug_info += f"Output Tokens: {outputs.sequences[0]}"
115
+ return fig, f"✅ Analysis complete\n\n{debug_info}"
 
116
 
117
  else:
118
  return None, "", "❌ No scores generated"
default-prompt.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ How would you respond to the following statement:
2
+ '{statement}'
3
+
4
+ Respond only with a number from 1 to 5, where:
5
+ 1 = Strongly Disagree
6
+ 2 = Disagree
7
+ 3 = Neutral
8
+ 4 = Agree
9
+ 5 = Strongly Agree