import logging
import re
from typing import Optional, List, Dict, Any
logger = logging.getLogger(__name__)


class LocalSummarizer:
    """Singleton-style wrapper for local LLM summarization.

    Extracts candidate description blocks from a model card using
    heuristic rules, scores them, and summarizes the best one with a
    small local seq2seq model. Falls back to the cleaned extracted text
    when the model is unavailable or produces an invalid summary.
    """

    # Class-level (shared) state. `_model` is None until a load is
    # attempted and False after a failed load, so we never retry the
    # (expensive) load on every call.
    _tokenizer = None
    _model = None
    _model_name = "google/flan-t5-small"

    @classmethod
    def _load_model(cls):
        """Lazily load the model and tokenizer directly."""
        if cls._model is None:
            try:
                from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
                import transformers
                logger.info(f"⏳ Loading summarization model ({cls._model_name})...")
                # Silence transformers' own chatter during the noisy load,
                # then restore the caller's verbosity.
                old_verbosity = transformers.logging.get_verbosity()
                transformers.logging.set_verbosity_error()
                cls._tokenizer = AutoTokenizer.from_pretrained(cls._model_name)
                cls._model = AutoModelForSeq2SeqLM.from_pretrained(cls._model_name)
                transformers.logging.set_verbosity(old_verbosity)
                logger.info("✅ Summarization model loaded successfully")
            except Exception as e:
                logger.error(f"❌ Failed to load summarization model: {e}")
                cls._model = False  # Mark as failed

    @staticmethod
    def _strip_yaml_frontmatter(text: str) -> str:
        """Strip a leading YAML frontmatter block enclosed in ---.

        Anchored with \\A (start of string) instead of a MULTILINE ^ so
        that `---` fences appearing later in the document (horizontal
        rules / thematic breaks) are not mistaken for frontmatter and
        stripped along with the text between them.
        """
        return re.sub(r'\A---\s*\n.*?\n---\s*\n', '', text, flags=re.DOTALL)

    @staticmethod
    def _extract_candidates(text: str) -> List[str]:
        """Collect candidate description blocks using layered heuristics."""
        candidates = []
        # 1. Section Headers (support "1. Introduction")
        heading_matches = re.finditer(r'^#+\s*(?:\d+[\.\)]?\s*)?(Description|Model [dD]escription|Model Overview|Overview|Introduction|Summary|モデル概要|Model Details)[^\n]*\n(.*?)(?=\n#+\s|\Z)', text, flags=re.MULTILINE | re.DOTALL)
        for match in heading_matches:
            if match.group(2).strip():
                candidates.append(match.group(2).strip())
        # 2. Inline Labels
        inline_matches = re.finditer(r'(?:Description:|Overview:|### Description:)\s*(.*?)(?=\n\n|\Z)', text, flags=re.DOTALL | re.IGNORECASE)
        for match in inline_matches:
            if match.group(1).strip():
                candidates.append(match.group(1).strip())
        # 3. Auto-generated fine-tuned leading sentences
        tuned_matches = re.finditer(r'^(?:The .*model is a .*|This model is a fine-tuned version of.*|This is a fine-tuned.*)', text, flags=re.MULTILINE | re.IGNORECASE)
        for match in tuned_matches:
            candidates.append(match.group(0).strip())
        # 4. Fallback: First meaningful paragraph
        # Strip some HTML first just for the fallback rule
        html_stripped = re.sub(r'<[^>]+>', '', text)
        paragraphs = re.split(r'\n\s*\n', html_stripped)
        for p in paragraphs:
            p = p.strip()
            if not p:
                continue
            if p.startswith('#'):
                continue
            # Skip heavy markdown like links/images/badges and github alerts
            if p.startswith('[!') or p.startswith('<a href') or p.startswith('> [!'):
                continue
            # If a paragraph has many links (like a table of contents / link directory)
            if p.count('](') > 3 or p.count('http') > 3:
                continue
            if len(p) > 50:
                candidates.append(p)
                break
        return candidates

    @staticmethod
    def _score_candidate(text: str) -> float:
        """Heuristically score a candidate block; higher is better."""
        score = 0.0
        text_lower = text.lower()
        # Length score (sweet spot between 50 and 1000 chars)
        if 50 < len(text) < 1000:
            score += 10.0
        # Reward definitional patterns
        if "is a" in text_lower or "fine-tuned version of" in text_lower or "trained on" in text_lower or "designed for" in text_lower:
            score += 20.0
        # Penalize bad patterns
        if "leaderboard" in text_lower or "benchmark" in text_lower or "results" in text_lower:
            score -= 50.0
        if "install" in text_lower or "how to run" in text_lower or "pip install" in text_lower or "read our guide" in text_lower:
            score -= 30.0
        # Penalize table/code-heavy paragraphs and bullet points
        if text.count('|') > 5 or text.count('```') >= 1 or text.count('\n- ') > 2 or text.count('\n* ') > 2:
            score -= 50.0
        return score

    @staticmethod
    def _clean_text(text: str) -> str:
        """Strip HTML, markdown artifacts, and known boilerplate lines."""
        # Remove HTML. Best-effort: the bs4 import lives *inside* the try
        # so a missing/broken bs4 degrades to regex-only cleanup instead
        # of raising ImportError out of the summarization pipeline.
        try:
            from bs4 import BeautifulSoup
            soup = BeautifulSoup(text, "html.parser")
            for tag in soup(["style", "script"]):
                tag.decompose()
            text = soup.get_text(separator=' ')
        except Exception:
            pass
        # Remove markdown images
        text = re.sub(r'!\[.*?\]\([^)]+\)', '', text)
        # Convert links to just text
        text = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', text)
        # Remove code blocks
        text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)
        # Remove inline code
        text = re.sub(r'`[^`]*`', '', text)
        # Remove tables
        text = re.sub(r'\|.*?\|', '', text)
        text = re.sub(r'(?m)^[-:| ]+$', '', text)  # table separators
        # Remove boilerplate line by line
        lines = text.split('\n')
        clean_lines = []
        for line in lines:
            line_lower = line.lower()
            if 'generated automatically' in line_lower and 'model card' in line_lower:
                continue
            if 'completed by the model author' in line_lower:
                continue
            if 'model cards for model reporting' in line_lower:
                continue
            clean_lines.append(line)
        text = '\n'.join(clean_lines)
        # Clean up whitespace
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    @staticmethod
    def _truncate(text: str, max_chars: int) -> str:
        """Clip `text` to at most `max_chars`, appending '...' when clipped.

        Shared by `_generate` and `summarize` (was duplicated inline);
        the max(0, ...) guard avoids a negative slice for tiny limits.
        """
        if len(text) > max_chars:
            return text[:max(0, max_chars - 3)] + "..."
        return text

    @classmethod
    def _generate(cls, prompt: str, max_output_chars: int) -> Optional[str]:
        """Run the local model on `prompt`; return None if unavailable or failing."""
        if cls._model is None:
            cls._load_model()
        # `_model` is False after a failed load; bail out quietly either way.
        if not cls._model or not cls._tokenizer:
            return None
        try:
            inputs = cls._tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
            generate_kwargs = {
                "max_length": 128,  # Increased by ~30% from 64
                "min_length": 15,   # Avoid single word outputs
                "do_sample": False,
                "num_beams": 4,
                "early_stopping": True,
                "repetition_penalty": 2.0
            }
            summary_ids = cls._model.generate(inputs["input_ids"], **generate_kwargs)
            summary = cls._tokenizer.decode(summary_ids[0], skip_special_tokens=True)
            summary = summary.strip()
            # Remove "Output:" prefix if present
            if summary.lower().startswith("output:"):
                summary = re.sub(r'^Output:\s*', '', summary, flags=re.IGNORECASE)
            return cls._truncate(summary, max_output_chars)
        except Exception as e:
            logger.warning(f"⚠️ Generation failed: {e}")
            return None

    @staticmethod
    def _is_valid_summary(summary: str, model_id: str) -> bool:
        """Reject degenerate, artifact-laden, or instruction-like summaries."""
        if not summary or len(summary) < 15:
            return False
        summary_lower = summary.lower()
        model_name = model_id.split('/')[-1].lower()
        if summary_lower == model_name or summary_lower == f"{model_name} model":
            return False
        # Check for markdown/html artifacts
        if '#' in summary or '<' in summary or '>' in summary or '*' in summary:
            return False
        # Check for instruction-like text
        if summary_lower.startswith("to install") or summary_lower.startswith("how to") or "pip install" in summary_lower:
            return False
        # Refuse literally copying bullet points (e.g. from table)
        if "- type:" in summary_lower or "number of parameters:" in summary_lower:
            return False
        return True

    @classmethod
    def summarize(cls, text: str, max_output_chars: int = 332, model_id: str = "") -> Optional[str]:
        """
        Robustly extract and summarize model description.

        Pipeline: strip YAML frontmatter -> extract candidate blocks ->
        score and pick the best -> clean -> LLM summarize with validation
        and one stricter retry -> fall back to the first sentences of the
        cleaned extracted text.
        """
        if not text or not text.strip():
            return None
        # 1. Strip YAML safely
        text_without_yaml = cls._strip_yaml_frontmatter(text)
        # 2. Extract multiple candidate description blocks
        candidates = cls._extract_candidates(text_without_yaml)
        if not candidates:
            # Fallback if candidates are absolutely empty
            candidates = [text_without_yaml[:1000]]
        # 3. Score candidates and pick best
        scored_candidates = [(c, cls._score_candidate(c)) for c in candidates]
        best_candidate = max(scored_candidates, key=lambda x: x[1])[0]
        # 4. Clean aggressively
        cleaned_text = cls._clean_text(best_candidate)
        if not cleaned_text.strip():
            return None
        # Extract just the first few sentences of the cleaned text to avoid confusing the small model
        # with training details that usually appear at the end of the paragraph.
        sentences = re.split(r'(?<=[.!?])\s+', cleaned_text)
        short_text = " ".join(sentences[:3])
        # 5 & 6 & 7. Summarize, Validate, Retry, Fallback
        prompt1 = f"In one sentence, explain what this AI model is designed to do based on this description:\n\n{short_text}"
        summary = cls._generate(prompt1, max_output_chars)
        if summary and cls._is_valid_summary(summary, model_id):
            return summary
        # Retry with stricter prompt
        logger.info("⚠️ First summary invalid, retrying with stricter prompt.")
        prompt2 = f"Summarize the main purpose of this AI model in one complete sentence:\n\n{cleaned_text}"
        summary2 = cls._generate(prompt2, max_output_chars)
        if summary2 and cls._is_valid_summary(summary2, model_id):
            return summary2
        # Fallback to cleaned text (first 1-2 sentences); reuse the
        # sentence split computed above instead of re-splitting.
        logger.info("⚠️ Both LLM summaries invalid, falling back to cleaned extracted text.")
        fallback_summary = " ".join(sentences[:2])
        return cls._truncate(fallback_summary, max_output_chars)
|