nsschw commited on
Commit
9b87613
·
verified ·
1 Parent(s): 89d0124

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -38
app.py CHANGED
@@ -1,45 +1,124 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
- import numpy as np
5
  import matplotlib.pyplot as plt
6
- import seaborn as sns
7
- import random
8
- from datetime import datetime
9
-
10
- class AcademicOpinionOMatic:
11
- def __init__(self, model_name="microsoft/Phi-4-mini-instruct"):
12
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
13
- self.model = AutoModelForCausalLM.from_pretrained(
14
- model_name,
15
- torch_dtype="auto",
16
- device_map="auto",
17
- trust_remote_code=True
18
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # Get token IDs for numbers 1-5
21
- self.likert_tokens = {
22
- i: self.tokenizer.encode(str(i), add_special_tokens=False)[0]
23
- for i in range(1, 6)
24
- }
25
- self.likert_token_ids = list(self.likert_tokens.values())
26
-
27
- def get_likert_probabilities(self, prompt, temperature=1.0):
28
- """Extract academic opinions with scientific precision!"""
 
29
 
30
- if isinstance(prompt, str):
31
- chat_prompt = [{"role": "user", "content": prompt}]
32
- elif isinstance(prompt, list):
33
- chat_prompt = prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  else:
35
- raise ValueError("Input must be string or chat messages (like a properly formatted citation!)")
36
-
37
- try:
38
- prompt_text = self.tokenizer.apply_chat_template(
39
- chat_prompt,
40
- tokenize=False,
41
- add_generation_prompt=True
42
- )
43
- except Exception:
44
- if isinstance(prompt, list):
45
- prompt_
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
4
  import matplotlib.pyplot as plt
5
+
6
+ # Initialize model (using a small, fast model for proof of concept)
7
+ MODEL_NAME = "microsoft/DialoGPT-medium"
8
+
9
+ print("Loading model...")
10
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
11
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
12
+
13
+ # Add padding token if missing
14
+ if tokenizer.pad_token is None:
15
+ tokenizer.pad_token = tokenizer.eos_token
16
+
17
+ # Get token IDs for numbers 1-5
18
+ likert_tokens = {}
19
+ for i in range(1, 6):
20
+ tokens = tokenizer.encode(str(i), add_special_tokens=False)
21
+ if tokens:
22
+ likert_tokens[i] = tokens[0]
23
+
24
+ print(f"Likert tokens: {likert_tokens}")
25
+
26
+ def analyze_statement(statement):
27
+ """
28
+ Analyze a statement and return Likert probabilities
29
+ """
30
+ try:
31
+ # Create prompt
32
+ prompt = f"""Rate this statement from 1-5 where:
33
+ 1 = Strongly Disagree
34
+ 2 = Disagree
35
+ 3 = Neutral
36
+ 4 = Agree
37
+ 5 = Strongly Agree
38
+
39
+ Statement: "{statement}"
40
+
41
+ Rating:"""
42
+
43
+ # Tokenize
44
+ inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
45
 
46
+ # Generate with output scores
47
+ with torch.no_grad():
48
+ outputs = model.generate(
49
+ **inputs,
50
+ max_new_tokens=1,
51
+ return_dict_in_generate=True,
52
+ output_scores=True,
53
+ do_sample=False,
54
+ pad_token_id=tokenizer.pad_token_id
55
+ )
56
 
57
+ # Get probabilities for first generated token
58
+ if outputs.scores:
59
+ logits = outputs.scores[0][0] # First token, first batch
60
+ probs = torch.softmax(logits, dim=-1)
61
+
62
+ # Extract Likert probabilities
63
+ likert_probs = {}
64
+ for value, token_id in likert_tokens.items():
65
+ likert_probs[value] = probs[token_id].item()
66
+
67
+ # Create simple bar chart
68
+ fig, ax = plt.subplots(figsize=(8, 5))
69
+ values = list(likert_probs.keys())
70
+ probabilities = list(likert_probs.values())
71
+
72
+ bars = ax.bar(values, probabilities, color='skyblue', edgecolor='navy')
73
+
74
+ # Add value labels
75
+ for bar, prob in zip(bars, probabilities):
76
+ height = bar.get_height()
77
+ ax.text(bar.get_x() + bar.get_width()/2., height + 0.01,
78
+ f'{prob:.3f}', ha='center', va='bottom')
79
+
80
+ ax.set_xlabel('Likert Scale')
81
+ ax.set_ylabel('Probability')
82
+ ax.set_title('Response Probability Distribution')
83
+ ax.set_xticks(values)
84
+ ax.set_ylim(0, max(probabilities) * 1.2)
85
+
86
+ plt.tight_layout()
87
+
88
+ # Format probabilities text
89
+ prob_text = "\n".join([f"{k}: {v:.4f}" for k, v in likert_probs.items()])
90
+
91
+ return fig, prob_text, "✅ Analysis complete"
92
+
93
  else:
94
+ return None, "", "❌ No scores generated"
95
+
96
+ except Exception as e:
97
+ return None, "", f"❌ Error: {str(e)}"
98
+
99
+ # Create Gradio interface
100
+ demo = gr.Interface(
101
+ fn=analyze_statement,
102
+ inputs=[
103
+ gr.Textbox(
104
+ label="Statement to Analyze",
105
+ placeholder="e.g., Climate change is a serious threat",
106
+ lines=3
107
+ )
108
+ ],
109
+ outputs=[
110
+ gr.Plot(label="Probability Distribution"),
111
+ gr.Textbox(label="Raw Probabilities", lines=6),
112
+ gr.Textbox(label="Status")
113
+ ],
114
+ title="🎯 Likert Scale Probability Extractor (Minimal)",
115
+ description="Enter a statement and see the probability distribution for Likert scale responses (1-5)",
116
+ examples=[
117
+ ["Climate change is a serious threat"],
118
+ ["Technology makes life better"],
119
+ ["Government should provide universal healthcare"]
120
+ ]
121
+ )
122
+
123
+ if __name__ == "__main__":
124
+ demo.launch()