Some Guy committed on
Commit
84f2f65
·
0 Parent(s):

Initial commit: text saliency pro with Gemma 2B

Browse files
Files changed (6) hide show
  1. .gitignore +49 -0
  2. Dockerfile +31 -0
  3. main.py +109 -0
  4. requirements.txt +6 -0
  5. static/index.html +228 -0
  6. test_playwright.py +56 -0
.gitignore ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # Virtual Environments
30
+ venv/
31
+ env/
32
+ .env/
33
+ ENV/
34
+ env.bak/
35
+ venv.bak/
36
+
37
+ # Logs and databases
38
+ *.log
39
+ result.png
40
+ server.log
41
+
42
+ # IDEs
43
+ .idea/
44
+ .vscode/
45
+ *.swp
46
+ *.swo
47
+
48
+ # macOS
49
+ .DS_Store
Dockerfile ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Use official Python image
FROM python:3.9-slim

# Set the working directory
WORKDIR /app

# Install system dependencies (needed for compiling some python packages if required)
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements on their own so the dependency layer is only rebuilt
# when requirements.txt changes
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Hugging Face Spaces require running as a non-root user (UID 1000).
# Create the user BEFORE copying the application and hand /app over to it;
# otherwise the directory and its files stay owned by root and the runtime
# user cannot write next to the code.
RUN useradd -m -u 1000 user \
    && chown user:user /app

# Copy the rest of the application, owned by the runtime user
COPY --chown=user:user . .

USER user
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

# Expose the port HF Spaces uses
EXPOSE 7860

# Command to run the application
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from fastapi import FastAPI
3
+ from fastapi.staticfiles import StaticFiles
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ from pydantic import BaseModel
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM
7
+ import uvicorn
8
+ import os
9
+
10
# ASGI application object served by uvicorn ("main:app").
app = FastAPI()

# Allow any origin to call the API. Credentials are switched off because a
# wildcard origin list must not be combined with allow_credentials=True, and
# nothing in this app relies on cookies or Authorization headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Resolve the static directory relative to this module rather than the current
# working directory, so the mount works no matter where the server is started.
_STATIC_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "static")
app.mount("/static", StaticFiles(directory=_STATIC_DIR), name="static")
21
+
22
model_id = "google/gemma-2b"

# Load the model and tokenizer globally.
# Use MPS if available, otherwise CPU. MPS (Metal Performance Shaders) works well on modern Macs.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Loading {model_id} on {device}...")

# Bind both names up front so that a failed load leaves them as None instead of
# unbound; request handlers can then detect the condition explicitly.
tokenizer = None
model = None

try:
    # Optional access token for the gated Gemma weights; None when unset.
    hf_token = os.environ.get("HF_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        # NOTE(review): presumably "eager" is chosen so attention weights can be
        # returned with output_attentions=True - confirm against transformers docs.
        attn_implementation="eager",
        token=hf_token,
    ).to(device)
    print("Model loaded successfully.")
except Exception as e:
    # Best effort: keep the server (and static files) up, but never leave a
    # half-initialised pair behind if only one of the two loads succeeded.
    tokenizer = None
    model = None
    print(f"Error loading model: {e}")
    print("Make sure you are logged into Hugging Face and have access to the Gemma model.")
    print("Run `huggingface-cli login` in your terminal.")
43
+
44
class TextRequest(BaseModel):
    """JSON body accepted by the ``/analyze`` endpoint."""

    # Raw input text whose tokens are scored.
    text: str
46
+
47
def _normalize_importance(importance):
    """Rescale raw per-token attention sums to the range [0, 1].

    The first entry (the ``<bos>`` token) is left out of the min/max
    calculation because it often receives very high attention and would skew
    the rest; it is pinned to 1.0 afterwards.

    Args:
        importance: 1-D NumPy array of raw scores, one per token.

    Returns:
        A sequence of the same length with every value in [0, 1].
    """
    if len(importance) <= 1:
        return [1.0] * len(importance)

    rest = importance[1:]
    min_score = rest.min()
    score_range = rest.max() - min_score

    if score_range > 0:
        normalized = (importance - min_score) / score_range
    else:
        # Every non-<bos> token has the same score (always true when there is
        # only one such token). Dividing by the zero range would produce
        # NaN/inf, which cannot be serialised to JSON, so treat the tie like
        # the single-token case and give every token the full score.
        normalized = importance.copy()
        normalized.fill(1.0)

    # Keep <bos> at max score
    normalized[0] = 1.0
    return normalized.clip(0, 1)


@app.post("/analyze")
async def analyze_text(request: TextRequest):
    """Score every token of the submitted text by the attention it receives.

    The text is run through the language model once. The last layer's
    attention is averaged over heads, and for each token the attention it
    receives from all positions is summed and normalized to [0, 1].

    Args:
        request: Request body carrying the text to analyze.

    Returns:
        ``{"words": [...]}`` with one ``{"token", "word", "score"}`` entry per
        token. For blank input the list is empty and the legacy ``tokens`` /
        ``scores`` keys are included as well.

    Raises:
        HTTPException: 503 when the model or tokenizer is not available.
    """
    # Imported locally so the handler carries its own dependency.
    from fastapi import HTTPException

    text = request.text
    if not text.strip():
        # "words" matches the shape of every other response; the two legacy
        # keys are kept for any client that relied on them.
        return {"words": [], "tokens": [], "scores": []}

    # Model loading at startup is best effort, so the globals may be missing
    # or None. Report that clearly instead of failing with an opaque 500.
    loaded_tokenizer = globals().get("tokenizer")
    loaded_model = globals().get("model")
    if loaded_tokenizer is None or loaded_model is None:
        raise HTTPException(status_code=503, detail="Model is not loaded.")

    inputs = loaded_tokenizer(text, return_tensors="pt").to(device)

    with torch.no_grad():
        # Ensure we ask the model to output attentions explicitly
        outputs = loaded_model(**inputs, output_attentions=True)

    # Check if attentions are actually returned
    if not outputs.attentions:
        print("Warning: Model did not return attentions.")
        return {"words": []}

    # outputs.attentions is a tuple with one tensor per layer, each of shape
    # (batch_size, num_heads, sequence_length, sequence_length); use the last layer.
    last_layer = outputs.attentions[-1]

    # Average across all heads for the single batch item -> (seq_len, seq_len)
    avg_attention = last_layer[0].mean(dim=0)

    # Importance: sum of attention each token *receives* from the sequence
    importance = avg_attention.sum(dim=0).cpu().float().numpy()
    normalized_scores = _normalize_importance(importance)

    input_ids = inputs["input_ids"][0].tolist()
    tokens = loaded_tokenizer.convert_ids_to_tokens(input_ids)

    # Send both the decoded word and the raw token (with the SentencePiece
    # space marker replaced by a normal space) so the frontend can choose
    # which one renders better.
    result = [
        {
            "token": raw_token.replace('\u2581', ' '),
            "word": loaded_tokenizer.decode([token_id]),
            "score": float(score),
        }
        for token_id, raw_token, score in zip(input_ids, tokens, normalized_scores)
    ]

    return {"words": result}
106
+
107
if __name__ == "__main__":
    # Direct-run entry point: serve on all interfaces with auto-reload enabled,
    # on the port given by the PORT environment variable (7860 when unset).
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        reload=True,
    )
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ torch
4
+ transformers
5
+ pydantic
6
+ accelerate
static/index.html ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Text Saliency Pro</title>
7
+ <style>
8
+ body {
9
+ font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
10
+ max-width: 800px;
11
+ margin: 0 auto;
12
+ padding: 2rem;
13
+ line-height: 1.5;
14
+ background-color: #f9fafb;
15
+ color: #111827;
16
+ }
17
+
18
+ h1 {
19
+ font-size: 2.5rem;
20
+ margin-bottom: 1rem;
21
+ text-align: center;
22
+ }
23
+
24
+ p.description {
25
+ text-align: center;
26
+ color: #4b5563;
27
+ margin-bottom: 2rem;
28
+ }
29
+
30
+ .container {
31
+ background: white;
32
+ padding: 2rem;
33
+ border-radius: 0.5rem;
34
+ box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);
35
+ }
36
+
37
+ textarea {
38
+ width: 100%;
39
+ height: 150px;
40
+ padding: 0.75rem;
41
+ border: 1px solid #d1d5db;
42
+ border-radius: 0.375rem;
43
+ font-size: 1rem;
44
+ resize: vertical;
45
+ margin-bottom: 1rem;
46
+ box-sizing: border-box;
47
+ }
48
+
49
+ .controls {
50
+ display: flex;
51
+ align-items: center;
52
+ justify-content: space-between;
53
+ margin-bottom: 1.5rem;
54
+ flex-wrap: wrap;
55
+ gap: 1rem;
56
+ }
57
+
58
+ .slider-group {
59
+ display: flex;
60
+ align-items: center;
61
+ gap: 1rem;
62
+ flex-grow: 1;
63
+ }
64
+
65
+ input[type="range"] {
66
+ flex-grow: 1;
67
+ max-width: 300px;
68
+ }
69
+
70
+ button {
71
+ background-color: #3b82f6;
72
+ color: white;
73
+ border: none;
74
+ padding: 0.5rem 1.5rem;
75
+ font-size: 1rem;
76
+ border-radius: 0.375rem;
77
+ cursor: pointer;
78
+ transition: background-color 0.2s;
79
+ }
80
+
81
+ button:hover {
82
+ background-color: #2563eb;
83
+ }
84
+
85
+ button:disabled {
86
+ background-color: #9ca3af;
87
+ cursor: not-allowed;
88
+ }
89
+
90
+ #result-container {
91
+ margin-top: 2rem;
92
+ padding: 1.5rem;
93
+ background-color: #f3f4f6;
94
+ border-radius: 0.375rem;
95
+ min-height: 100px;
96
+ white-space: pre-wrap;
97
+ font-size: 1.125rem;
98
+ }
99
+
100
+ /* Token specific styles */
101
+ .token {
102
+ transition: font-weight 0.2s;
103
+ }
104
+
105
+ .highlighted {
106
+ font-weight: 800; /* Extra bold */
107
+ color: #000;
108
+ }
109
+
110
+ #loading {
111
+ display: none;
112
+ color: #6b7280;
113
+ text-align: center;
114
+ margin-top: 1rem;
115
+ }
116
+ </style>
117
+ </head>
118
+ <body>
119
+
120
+ <h1>Text Saliency Pro</h1>
121
+ <p class="description">Improve reading comprehension using LLM attention vectors.<br>Words with attention above the threshold will be bolded.</p>
122
+
123
+ <div class="container">
124
+ <textarea id="text-input" placeholder="Enter or paste your text here...">In this project I want to use the attention vectors of a llm to bold the most important words in an input text to improve reading comprehension.</textarea>
125
+
126
+ <div class="controls">
127
+ <button id="analyze-btn">Analyze Text</button>
128
+ <div class="slider-group">
129
+ <label for="threshold">Attention Threshold: <span id="threshold-val">0.50</span></label>
130
+ <input type="range" id="threshold" min="0" max="1" step="0.01" value="0.5">
131
+ </div>
132
+ </div>
133
+
134
+ <div id="loading">Analyzing attention vectors with Gemma 2B... Please wait.</div>
135
+
136
+ <div id="result-container">
137
+ <!-- Processed text will appear here -->
138
+ </div>
139
+ </div>
140
+
141
+ <script>
142
// DOM nodes used by the script, looked up once.
const inputArea = document.querySelector('#text-input');
const analyzeBtn = document.querySelector('#analyze-btn');
const thresholdSlider = document.querySelector('#threshold');
const thresholdVal = document.querySelector('#threshold-val');
const resultContainer = document.querySelector('#result-container');
const loading = document.querySelector('#loading');

// Tokens from the most recent analysis: [{token: str, word: str, score: float}, ...]
let currentTokens = [];

// Keep the numeric label in sync with the slider and re-apply the bolding
// immediately, without another backend call.
thresholdSlider.addEventListener('input', (event) => {
    const threshold = Number.parseFloat(event.target.value);
    thresholdVal.textContent = threshold.toFixed(2);
    renderTokens();
});
156
+
157
// Analyze text when button is clicked: POST the textarea content to /analyze
// and render the returned tokens.
analyzeBtn.addEventListener('click', async () => {
    const text = inputArea.value.trim();
    if (!text) return;

    // Update UI state
    analyzeBtn.disabled = true;
    loading.style.display = 'block';
    resultContainer.innerHTML = '';
    // Forget the previous result now. Otherwise moving the slider while the
    // request is running, or after it failed, would re-render the old tokens
    // over the cleared container or the error message.
    currentTokens = [];

    try {
        // Call the FastAPI backend
        const response = await fetch('/analyze', {
            method: 'POST',
            headers: {
                'Content-Type': 'application/json'
            },
            body: JSON.stringify({ text })
        });

        if (!response.ok) {
            throw new Error('Network response was not ok');
        }

        const data = await response.json();
        currentTokens = data.words || [];
        renderTokens();

    } catch (error) {
        console.error('Error analyzing text:', error);
        resultContainer.innerHTML = '<span style="color: red;">Error analyzing text. Is the backend running?</span>';
    } finally {
        // Restore UI state
        analyzeBtn.disabled = false;
        loading.style.display = 'none';
    }
});
194
+
195
// Rebuild the result view from currentTokens, bolding every token whose score
// reaches the slider's current threshold. Does nothing when there are no tokens.
function renderTokens() {
    if (currentTokens.length === 0) return;

    const threshold = parseFloat(thresholdSlider.value);
    resultContainer.innerHTML = ''; // Clear existing

    for (const [index, item] of currentTokens.entries()) {
        // The leading <bos> marker is not part of the user's text; leave it out.
        const isBos = index === 0
            && (item.token.includes('<bos>') || item.word.includes('<bos>'));
        if (isBos) continue;

        const span = document.createElement('span');
        span.className = 'token';
        span.classList.toggle('highlighted', item.score >= threshold);

        // Show the raw token: the backend already replaced the tokenizer's
        // space marker with a normal space, so concatenating the spans keeps
        // the spacing between subwords.
        span.textContent = item.token;
        resultContainer.appendChild(span);
    }
}
226
+ </script>
227
+ </body>
228
+ </html>
test_playwright.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from playwright.sync_api import sync_playwright
2
+ import time
3
+
4
def test_app():
    """Drive the web UI in headless Chromium and report which tokens get bolded.

    Opens the page, clicks "Analyze Text" with the textarea's default content,
    waits for token spans to appear, prints the text with highlighted tokens
    wrapped in ``**`` and saves a full-page screenshot to ``result.png``.

    The page URL defaults to ``http://localhost:8000/static/index.html`` and
    can be overridden with the ``APP_URL`` environment variable, for example
    when the server listens on port 7860.
    """
    # Imported here so the function carries its own dependency.
    import os

    url = os.environ.get("APP_URL", "http://localhost:8000/static/index.html")

    with sync_playwright() as p:
        print("Launching browser...")
        browser = p.chromium.launch(headless=True)
        # Close the browser even when a step below raises (navigation error,
        # selector timeout, ...), so no headless Chromium process is left behind.
        try:
            page = browser.new_page()

            print(f"Navigating to {url}...")
            page.goto(url)

            # We'll just use the default text already in the textarea.

            print("Clicking the 'Analyze Text' button...")
            page.click("#analyze-btn")

            print("Waiting for the analysis to finish (this might take a few seconds)...")
            # Wait until at least one token span has been rendered
            page.wait_for_selector(".token", timeout=60000)

            # Get all tokens and their classes
            tokens = page.query_selector_all(".token")

            print("\n--- Results ---")
            highlighted_words = []
            full_text = []

            for token in tokens:
                text = token.inner_text()
                # get_attribute may return None; compare whole class names
                # rather than substrings.
                classes = (token.get_attribute("class") or "").split()

                # Format output
                if "highlighted" in classes:
                    full_text.append(f"**{text}**")
                    highlighted_words.append(text)
                else:
                    full_text.append(text)

            print("Full output with bolded words (marked by **):")
            # Simple join (there might be spaces in the tokens themselves based on Gemma's tokenizer)
            print("".join(full_text))

            print("\nWords that crossed the attention threshold:")
            print(highlighted_words)

            print("\nSaving screenshot to result.png...")
            page.screenshot(path="result.png", full_page=True)
        finally:
            browser.close()
        print("Done!")


if __name__ == "__main__":
    test_app()