File size: 11,337 Bytes
6165ba9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
import logging
import re
from typing import Optional, List, Dict, Any

logger = logging.getLogger(__name__)

class LocalSummarizer:
    """Singleton-style wrapper for local LLM summarization of model cards.

    Pipeline: strip YAML frontmatter, extract candidate description blocks
    via heuristic rules, score and pick the best candidate, clean it of
    markup, then ask a small local seq2seq model (flan-t5-small) for a
    one-sentence summary — with validation, one stricter retry, and a
    plain extracted-text fallback if generation fails or is rejected.
    """

    # Lazy-loaded class-level singletons.
    # _model is None (not yet loaded), False (load failed — don't retry),
    # or the loaded model object.
    _tokenizer = None
    _model = None
    _model_name = "google/flan-t5-small"

    @classmethod
    def _load_model(cls) -> None:
        """Lazily load the tokenizer and model; mark ``_model = False`` on failure."""
        if cls._model is None:
            try:
                from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
                import transformers
                logger.info(f"⏳ Loading summarization model ({cls._model_name})...")

                # Temporarily silence transformers' own warning/progress chatter
                # during the (potentially slow) download/load.
                old_verbosity = transformers.logging.get_verbosity()
                transformers.logging.set_verbosity_error()

                cls._tokenizer = AutoTokenizer.from_pretrained(cls._model_name)
                cls._model = AutoModelForSeq2SeqLM.from_pretrained(cls._model_name)

                transformers.logging.set_verbosity(old_verbosity)
                logger.info("✅ Summarization model loaded successfully")
            except Exception as e:
                logger.error(f"❌ Failed to load summarization model: {e}")
                cls._model = False  # Mark as failed so every later call fails fast

    @staticmethod
    def _strip_yaml_frontmatter(text: str) -> str:
        """Strip a YAML frontmatter block (``--- ... ---``) from the start of *text*.

        Anchored to the beginning of the string (``\\A``) and limited to a
        single substitution: the previous ``^``+MULTILINE version could also
        delete spans between two ``---`` horizontal rules in the middle of
        the document.
        """
        return re.sub(r'\A---\s*\n.*?\n---\s*\n', '', text, count=1, flags=re.DOTALL)

    @staticmethod
    def _extract_candidates(text: str) -> List[str]:
        """Collect candidate description blocks from a model-card body.

        Four heuristic rules, in order: markdown section headings, inline
        "Description:"-style labels, auto-generated fine-tune lead sentences,
        and a first-meaningful-paragraph fallback. Returns possibly-empty list.
        """
        candidates = []

        # 1. Section Headers (optionally numbered like "1. Introduction")
        heading_matches = re.finditer(r'^#+\s*(?:\d+[\.\)]?\s*)?(Description|Model [dD]escription|Model Overview|Overview|Introduction|Summary|モデル概要|Model Details)[^\n]*\n(.*?)(?=\n#+\s|\Z)', text, flags=re.MULTILINE | re.DOTALL)
        for match in heading_matches:
            if match.group(2).strip():
                candidates.append(match.group(2).strip())

        # 2. Inline Labels ("Description: ..." up to the next blank line)
        inline_matches = re.finditer(r'(?:Description:|Overview:|### Description:)\s*(.*?)(?=\n\n|\Z)', text, flags=re.DOTALL | re.IGNORECASE)
        for match in inline_matches:
            if match.group(1).strip():
                candidates.append(match.group(1).strip())

        # 3. Auto-generated fine-tuned leading sentences (single lines)
        tuned_matches = re.finditer(r'^(?:The .*model is a .*|This model is a fine-tuned version of.*|This is a fine-tuned.*)', text, flags=re.MULTILINE | re.IGNORECASE)
        for match in tuned_matches:
            candidates.append(match.group(0).strip())

        # 4. Fallback: first meaningful paragraph.
        # Strip HTML tags first, just for this rule, so badge/banner markup
        # doesn't disguise the real first paragraph.
        html_stripped = re.sub(r'<[^>]+>', '', text)
        paragraphs = re.split(r'\n\s*\n', html_stripped)
        for p in paragraphs:
            p = p.strip()
            if not p:
                continue
            if p.startswith('#'):
                continue
            # Skip heavy markdown like links/images/badges and github alerts
            if p.startswith('[!') or p.startswith('<a href') or p.startswith('> [!'):
                continue
            # Skip link-dense paragraphs (tables of contents / link directories)
            if p.count('](') > 3 or p.count('http') > 3:
                continue
            if len(p) > 50:
                candidates.append(p)
                break

        return candidates

    @staticmethod
    def _score_candidate(text: str) -> float:
        """Heuristically score a candidate block; higher is more description-like."""
        score = 0.0
        text_lower = text.lower()

        # Length score (accept anything between 50 and 1000 chars)
        if 50 < len(text) < 1000:
            score += 10.0

        # Reward definitional patterns ("X is a ...", "fine-tuned version of ...")
        if "is a" in text_lower or "fine-tuned version of" in text_lower or "trained on" in text_lower or "designed for" in text_lower:
            score += 20.0

        # Penalize benchmark/leaderboard sections and install/how-to text
        if "leaderboard" in text_lower or "benchmark" in text_lower or "results" in text_lower:
            score -= 50.0
        if "install" in text_lower or "how to run" in text_lower or "pip install" in text_lower or "read our guide" in text_lower:
            score -= 30.0

        # Penalize table/code-heavy paragraphs and bullet-point lists
        if text.count('|') > 5 or text.count('```') >= 1 or text.count('\n- ') > 2 or text.count('\n* ') > 2:
            score -= 50.0

        return score

    @staticmethod
    def _clean_text(text: str) -> str:
        """Aggressively strip HTML, markdown markup, tables, and boilerplate lines.

        Returns a single whitespace-normalized line of plain text.
        """
        # Remove HTML via BeautifulSoup when available. The import lives
        # inside the try so a missing bs4 package degrades gracefully to the
        # regex-only cleanup below instead of raising ImportError.
        try:
            from bs4 import BeautifulSoup
            soup = BeautifulSoup(text, "html.parser")
            for tag in soup(["style", "script"]):
                tag.decompose()
            text = soup.get_text(separator=' ')
        except Exception:
            pass

        # Remove markdown images
        text = re.sub(r'!\[.*?\]\([^)]+\)', '', text)
        # Convert links to just their anchor text
        text = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', text)
        # Remove fenced code blocks
        text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)
        # Remove inline code
        text = re.sub(r'`[^`]*`', '', text)
        # Remove table cells and separator rows
        text = re.sub(r'\|.*?\|', '', text)
        text = re.sub(r'(?m)^[-:| ]+$', '', text)  # table separators

        # Drop known model-card boilerplate line by line
        lines = text.split('\n')
        clean_lines = []
        for line in lines:
            line_lower = line.lower()
            if 'generated automatically' in line_lower and 'model card' in line_lower:
                continue
            if 'completed by the model author' in line_lower:
                continue
            if 'model cards for model reporting' in line_lower:
                continue
            clean_lines.append(line)
        text = '\n'.join(clean_lines)

        # Collapse all whitespace runs to single spaces
        text = re.sub(r'\s+', ' ', text).strip()

        return text

    @classmethod
    def _generate(cls, prompt: str, max_output_chars: int) -> Optional[str]:
        """Run the local model on *prompt*; return a trimmed summary or None.

        Returns None when the model failed to load or generation raises.
        """
        if cls._model is None:
            cls._load_model()
        if not cls._model or not cls._tokenizer:
            return None

        try:
            inputs = cls._tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
            generate_kwargs = {
                "max_length": 128,  # Increased by ~30% from 64
                "min_length": 15,  # Avoid single word outputs
                "do_sample": False,
                "num_beams": 4,
                "early_stopping": True,
                "repetition_penalty": 2.0
            }
            summary_ids = cls._model.generate(inputs["input_ids"], **generate_kwargs)
            summary = cls._tokenizer.decode(summary_ids[0], skip_special_tokens=True)

            summary = summary.strip()

            # Remove "Output:" prefix if present (flan-t5 sometimes echoes it)
            if summary.lower().startswith("output:"):
                summary = re.sub(r'^Output:\s*', '', summary, flags=re.IGNORECASE)

            if len(summary) > max_output_chars:
                return summary[:max_output_chars-3] + "..."
            return summary
        except Exception as e:
            logger.warning(f"⚠️ Generation failed: {e}")
            return None

    @staticmethod
    def _is_valid_summary(summary: str, model_id: str) -> bool:
        """Reject degenerate model outputs: too short, name-echoes, markup, how-tos."""
        if not summary or len(summary) < 15:
            return False

        summary_lower = summary.lower()
        model_name = model_id.split('/')[-1].lower()

        # Reject summaries that merely echo the model name
        if summary_lower == model_name or summary_lower == f"{model_name} model":
            return False

        # Check for markdown/html artifacts
        if '#' in summary or '<' in summary or '>' in summary or '*' in summary:
            return False

        # Check for instruction-like text
        if summary_lower.startswith("to install") or summary_lower.startswith("how to") or "pip install" in summary_lower:
            return False

        # Refuse literally copied bullet points (e.g. from a spec table)
        if "- type:" in summary_lower or "number of parameters:" in summary_lower:
            return False

        return True

    @classmethod
    def summarize(cls, text: str, max_output_chars: int = 332, model_id: str = "") -> Optional[str]:
        """Robustly extract and summarize a model description.

        Args:
            text: Raw model-card markdown (may include YAML frontmatter).
            max_output_chars: Hard cap on the returned summary length.
            model_id: "org/name" id, used only to reject name-echo summaries.

        Returns:
            A one-sentence summary, a cleaned extracted fallback, or None
            when *text* is empty or cleaning leaves nothing usable.
        """
        if not text or not text.strip():
            return None

        # 1. Strip YAML safely
        text_without_yaml = cls._strip_yaml_frontmatter(text)

        # 2. Extract multiple candidate description blocks
        candidates = cls._extract_candidates(text_without_yaml)

        if not candidates:
            # Fallback if candidates are absolutely empty
            candidates = [text_without_yaml[:1000]]

        # 3. Score candidates and pick best
        scored_candidates = [(c, cls._score_candidate(c)) for c in candidates]
        best_candidate = max(scored_candidates, key=lambda x: x[1])[0]

        # 4. Clean aggressively
        cleaned_text = cls._clean_text(best_candidate)

        if not cleaned_text.strip():
            return None

        # Feed only the first few sentences to the small model: training
        # details that usually trail the paragraph tend to confuse it.
        sentences = re.split(r'(?<=[.!?])\s+', cleaned_text)
        short_text = " ".join(sentences[:3])

        # 5 & 6 & 7. Summarize, Validate, Retry, Fallback
        prompt1 = f"In one sentence, explain what this AI model is designed to do based on this description:\n\n{short_text}"

        summary = cls._generate(prompt1, max_output_chars)

        if summary and cls._is_valid_summary(summary, model_id):
            return summary

        # Retry with stricter prompt over the full cleaned text
        logger.info("⚠️ First summary invalid, retrying with stricter prompt.")
        prompt2 = f"Summarize the main purpose of this AI model in one complete sentence:\n\n{cleaned_text}"
        summary2 = cls._generate(prompt2, max_output_chars)

        if summary2 and cls._is_valid_summary(summary2, model_id):
            return summary2

        # Fallback to cleaned text (first 1-2 sentences); reuse the sentence
        # split computed above instead of re-splitting the same text.
        logger.info("⚠️ Both LLM summaries invalid, falling back to cleaned extracted text.")
        fallback_summary = " ".join(sentences[:2])
        if len(fallback_summary) > max_output_chars:
            return fallback_summary[:max_output_chars-3] + "..."
        return fallback_summary