"""
AI Research Paper Helper - Equation Explainer Service
Parses and explains LaTeX equations in plain English.
"""
import re
import logging
from typing import List, Optional, Dict
from dataclasses import dataclass
import httpx
from config import settings, get_llm_config
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@dataclass
class VariableExplanation:
    """One symbol found in an equation, paired with a short description.

    Attributes:
        symbol: Display form of the variable (e.g. a Unicode Greek letter).
        latex: The LaTeX source that produced the symbol.
        description: Plain-English meaning of the variable.
    """

    # Field order is part of the positional constructor signature.
    symbol: str
    latex: str
    description: str
@dataclass
class EquationExplanation:
    """Everything the service reports about a single equation.

    Attributes:
        readable: Human-readable rendering of the equation.
        meaning: What the equation represents.
        variables: Explanations for the symbols used in the equation.
        importance: Why the equation matters.
        equation_type: Category label such as loss, gradient or probability.
    """

    # Field order is part of the positional constructor signature.
    readable: str
    meaning: str
    variables: List[VariableExplanation]
    importance: str
    equation_type: str
class EquationExplainerService:
    """Service for explaining mathematical equations.

    The service cleans a LaTeX equation, classifies it with regex heuristics,
    renders a plain-text form, lists the symbols it uses and produces a short
    explanation, either from built-in templates or from an LLM endpoint.
    """

    # LaTeX command -> Unicode symbol. Iteration order is the order in which
    # Greek variables are reported by ``_extract_variables``, so new entries
    # are appended at the end.
    SYMBOL_MAP = {
        r'\alpha': 'α', r'\beta': 'β', r'\gamma': 'γ', r'\delta': 'δ',
        r'\epsilon': 'ε', r'\theta': 'θ', r'\lambda': 'λ', r'\mu': 'μ',
        r'\sigma': 'σ', r'\phi': 'φ', r'\psi': 'ψ', r'\omega': 'ω',
        r'\Sigma': 'Σ', r'\Pi': 'Π', r'\Omega': 'Ω',
        r'\sum': 'Σ', r'\prod': 'Π', r'\int': '∫',
        r'\infty': '∞', r'\partial': '∂', r'\nabla': '∇',
        r'\leq': '≤', r'\geq': '≥', r'\neq': '≠', r'\approx': '≈',
        r'\in': '∈', r'\forall': '∀', r'\exists': '∃',
        r'\rightarrow': '→', r'\leftarrow': '←', r'\Rightarrow': '⇒',
        r'\cdot': '·', r'\times': '×', r'\pm': '±',
        # Additional Greek letters (``η`` already has an entry in
        # COMMON_VARIABLES and was unreachable without ``\eta`` here).
        r'\eta': 'η', r'\zeta': 'ζ', r'\kappa': 'κ', r'\nu': 'ν',
        r'\xi': 'ξ', r'\pi': 'π', r'\rho': 'ρ', r'\tau': 'τ',
        r'\chi': 'χ', r'\varepsilon': 'ε', r'\varphi': 'φ',
        r'\Gamma': 'Γ', r'\Delta': 'Δ', r'\Theta': 'Θ', r'\Lambda': 'Λ',
        r'\Phi': 'Φ', r'\Psi': 'Ψ',
        # Short aliases and ellipses.
        r'\le': '≤', r'\ge': '≥', r'\ne': '≠', r'\to': '→',
        r'\ldots': '…', r'\dots': '…', r'\cdots': '⋯',
    }

    # Matches any SYMBOL_MAP command as a *whole* command: longest names are
    # tried first and the lookahead rejects a longer command that merely
    # starts with a mapped one (``\cdots`` vs ``\cdot``, ``\inf`` vs ``\in``).
    _SYMBOL_RE = re.compile(
        '(?:'
        + '|'.join(re.escape(cmd) for cmd in sorted(SYMBOL_MAP, key=len, reverse=True))
        + ')(?![a-zA-Z])'
    )

    # Commands whose symbol is a letter but which are operators, not variables.
    _BIG_OPERATORS = frozenset({r'\sum', r'\prod'})

    # Operator names that stay readable as plain words once the backslash
    # is dropped (``\log x`` -> ``log x``).
    _NAMED_OPERATORS = frozenset({
        'log', 'ln', 'exp', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh',
        'max', 'min', 'sup', 'inf', 'lim', 'det', 'dim', 'arg', 'ker',
        'gcd', 'deg', 'Pr',
    })

    # Brace-group rewrites applied innermost-first by ``_to_readable``. Each
    # group is required to be brace-free so nested constructs resolve from the
    # inside out.
    _GROUP_REWRITES = (
        (re.compile(r'\\[dt]?frac\{([^{}]+)\}\{([^{}]+)\}'), r'(\1)/(\2)'),
        (re.compile(r'\^\{([^{}]+)\}'), r'^(\1)'),
        (re.compile(r'_\{([^{}]+)\}'), r'_(\1)'),
        (re.compile(r'\\sqrt\{([^{}]+)\}'), r'√(\1)'),
    )

    # Any other ``\command{content}``; fractions and roots are excluded
    # because unwrapping their first argument would destroy them.
    _GENERIC_GROUP_RE = re.compile(
        r'\\(?!(?:[dt]?frac|sqrt)(?![a-zA-Z]))([a-zA-Z]+)\{([^{}]+)\}'
    )

    # Literal relation characters or their LaTeX commands.
    _INEQUALITY_RE = re.compile(
        r'[<>≤≥≠]|\\(?:leq?|geq?|neq?|lt|gt)(?![a-zA-Z])'
    )

    # Single Latin letters that are not reported as variables.
    _LATIN_SKIP = frozenset('deifg')

    # Common ML equation patterns, checked in this order; the first type with
    # a matching pattern wins. Patterns are searched case-insensitively.
    EQUATION_PATTERNS = {
        'loss': [r'\\mathcal\{L\}', r'\\text\{loss\}', r'loss', r'J\(', r'L\('],
        'gradient': [r'\\nabla', r'\\partial', r'\\frac\{\\partial'],
        # The lookahead keeps ``\Pr`` from matching ``\prod`` / ``\prime``.
        'probability': [r'\\mathbb\{P\}', r'\\Pr(?![a-zA-Z])', r'p\(', r'P\('],
        'expectation': [r'\\mathbb\{E\}', r'\\E\[', r'E\['],
        'softmax': [r'softmax', r'\\text\{softmax\}', r'\\frac\{e\^'],
        'attention': [r'attention', r'Attention', r'\\text\{Attention\}'],
        # ``\\\|`` is a literal ``\|``; the previous ``\\|.*?\\|`` was an
        # alternation with an empty branch and matched every equation.
        'norm': [r'\|.*?\|', r'\\\|.*?\\\|', r'\\lVert', r'\\rVert'],
    }

    # Common variable meanings in ML
    COMMON_VARIABLES = {
        'x': 'input data or features',
        'y': 'target or label',
        'w': 'weight parameter',
        'W': 'weight matrix',
        'b': 'bias term',
        'θ': 'model parameters',
        'α': 'learning rate or scaling factor',
        'β': 'momentum coefficient or scaling factor',
        'λ': 'regularization coefficient',
        'σ': 'sigmoid function or standard deviation',
        'μ': 'mean',
        'ε': 'small constant for numerical stability',
        'η': 'learning rate',
        'L': 'loss function',
        'J': 'cost function',
        'n': 'number of samples',
        'm': 'batch size or number of features',
        'h': 'hidden state or layer',
        'z': 'latent variable or pre-activation',
        'a': 'activation',
        'p': 'probability',
        'q': 'approximate distribution',
        'K': 'number of classes',
        'T': 'temperature or time steps',
        'd': 'dimension',
    }

    # Template explanations used when no LLM is consulted, keyed by type.
    MEANING_TEMPLATES = {
        'loss': "This is a loss function that measures the error between predictions and actual values.",
        'gradient': "This computes the gradient (rate of change) of a function with respect to its parameters.",
        'probability': "This represents a probability distribution or conditional probability.",
        'expectation': "This calculates the expected value (average) over a probability distribution.",
        'softmax': "This applies the softmax function to convert values into probabilities that sum to 1.",
        'attention': "This computes attention weights to determine how much focus to give to different parts of the input.",
        'norm': "This calculates a norm (magnitude) of a vector or matrix.",
        'definition': "This defines a relationship or function between variables.",
        'equation': "This establishes equality between two mathematical expressions.",
        'inequality': "This describes a constraint or bound on values.",
        'expression': "This is a mathematical expression involving the given variables.",
    }

    IMPORTANCE_TEMPLATES = {
        'loss': "Loss functions are crucial for training ML models - they define what the model optimizes for.",
        'gradient': "Gradients enable backpropagation, allowing the model to learn by adjusting parameters.",
        'probability': "Probability formulations help the model reason about uncertainty and make predictions.",
        'expectation': "Expected values help in optimization and understanding average model behavior.",
        'softmax': "Softmax is fundamental for classification tasks, converting logits to class probabilities.",
        'attention': "Attention mechanisms allow models to focus on relevant parts of input, key for transformers.",
        'norm': "Norms help measure and control the magnitude of values, important for regularization.",
    }

    DEFAULT_IMPORTANCE = "This equation contributes to the mathematical foundation of the method."

    async def explain(
        self,
        equation: str,
        context: Optional[str] = None,
        format: str = 'latex'
    ) -> EquationExplanation:
        """Explain an equation in plain English.

        Args:
            equation: LaTeX or MathML equation string.
            context: Surrounding text for better understanding.
            format: 'latex' or 'mathml'. Accepted for interface
                compatibility; the implementation does not read it and
                always treats the input as LaTeX.

        Returns:
            EquationExplanation with all components. The LLM path is used
            when ``settings.api_mode`` is ``'api'`` and ``get_llm_config()``
            returns a truthy value; otherwise templates are used.
        """
        clean_eq = self._clean_equation(equation)
        eq_type = self._detect_type(clean_eq)
        readable = self._to_readable(clean_eq)
        variables = self._extract_variables(clean_eq, context)

        if settings.api_mode == 'api' and get_llm_config():
            return await self._explain_with_llm(
                clean_eq, context, eq_type, readable, variables
            )
        return self._explain_local(clean_eq, eq_type, readable, variables)

    def _clean_equation(self, equation: str) -> str:
        """Strip math-mode delimiters and normalize whitespace.

        Removes ``\\[`` / ``\\]``, ``$`` / ``$$`` and ``\\begin{...}`` /
        ``\\end{...}`` markers, then collapses whitespace runs to one space.
        """
        clean = equation.strip()
        clean = re.sub(r'\\\[|\\\]', '', clean)
        clean = re.sub(r'\$\$?', '', clean)
        clean = re.sub(r'\\begin\{[^}]+\}|\\end\{[^}]+\}', '', clean)
        clean = re.sub(r'\s+', ' ', clean)
        return clean.strip()

    def _detect_type(self, equation: str) -> str:
        """Classify the equation with regex heuristics.

        Returns the first key of ``EQUATION_PATTERNS`` with a matching
        pattern. Otherwise falls back on structure: ``'definition'`` (has
        ``=`` and a fraction), ``'equation'`` (has ``=``), ``'inequality'``
        (has a relation character or command), else ``'expression'``.
        """
        for eq_type, patterns in self.EQUATION_PATTERNS.items():
            if any(re.search(p, equation, re.IGNORECASE) for p in patterns):
                return eq_type

        if '=' in equation:
            if re.search(r'\\[dt]?frac', equation):
                return 'definition'
            return 'equation'
        # The input is still LaTeX here, so relation commands such as
        # ``\leq`` must be recognized alongside the literal characters.
        if self._INEQUALITY_RE.search(equation):
            return 'inequality'
        return 'expression'

    def _unwrap_group(self, match) -> str:
        """Rewrite ``\\command{content}``: keep named operators, drop the rest."""
        name, content = match.group(1), match.group(2)
        if name in self._NAMED_OPERATORS:
            # Keep the backslash so the final clean-up handles spacing.
            return '\\' + name + '(' + content + ')'
        return content

    def _strip_command(self, match) -> str:
        """Drop a leftover LaTeX command, keeping named operators as words."""
        name = match.group(1)
        if name not in self._NAMED_OPERATORS:
            return ''
        start = match.start()
        preceding = match.string[start - 1] if start > 0 else ''
        # Avoid gluing the operator name onto a preceding token ("2log x").
        if preceding.isalnum() or preceding == ')':
            return ' ' + name
        return name

    def _to_readable(self, equation: str) -> str:
        """Convert LaTeX to a plain-text, human-readable form.

        Mapped commands become Unicode symbols, fractions become
        ``(num)/(den)``, super/subscripts become ``^(..)`` / ``_(..)``,
        roots become ``√(..)``, and other ``\\command{..}`` wrappers are
        unwrapped. Remaining commands and braces are removed, except named
        operators such as ``log``, which are kept as words.
        """
        readable = self._SYMBOL_RE.sub(
            lambda m: self.SYMBOL_MAP.get(m.group(0), m.group(0)), equation
        )

        # Resolve brace groups from the inside out. Every successful rewrite
        # removes at least one pair of braces, so the loop terminates.
        while True:
            rewritten = readable
            for pattern, template in self._GROUP_REWRITES:
                rewritten = pattern.sub(template, rewritten)
            rewritten = self._GENERIC_GROUP_RE.sub(self._unwrap_group, rewritten)
            if rewritten == readable:
                break
            readable = rewritten

        # Un-braced single-character superscripts/subscripts.
        readable = re.sub(r'\^(\w)', r'^(\1)', readable)
        readable = re.sub(r'_(\w)', r'_(\1)', readable)

        # Clean remaining LaTeX commands, braces and whitespace runs.
        readable = re.sub(r'\\([a-zA-Z]+)', self._strip_command, readable)
        readable = re.sub(r'[{}]', '', readable)
        readable = re.sub(r'\s+', ' ', readable)
        return readable.strip()

    def _extract_variables(
        self,
        equation: str,
        context: Optional[str]
    ) -> List[VariableExplanation]:
        """Extract and explain variables from the equation.

        Greek letters are reported first, in ``SYMBOL_MAP`` order, followed
        by single Latin letters in order of appearance. Operators and
        relations (``\\sum``, ``\\leq``, ...) are not variables and are
        skipped. ``context`` is accepted but not used.

        Returns:
            At most ten ``VariableExplanation`` entries.
        """
        variables = []
        found = set()

        # Whole-command matching: ``\int`` must not count as ``\in``.
        present = {m.group(0) for m in self._SYMBOL_RE.finditer(equation)}
        for latex, symbol in self.SYMBOL_MAP.items():
            if latex not in present or symbol in found:
                continue
            if not symbol.isalpha() or latex in self._BIG_OPERATORS:
                continue
            variables.append(VariableExplanation(
                symbol=symbol,
                latex=latex,
                description=self.COMMON_VARIABLES.get(symbol, "parameter"),
            ))
            found.add(symbol)

        # Single Latin letters that are neither part of a longer word nor
        # part of a LaTeX command name.
        latin_matches = re.findall(
            r'(?<![a-zA-Z\\])([a-zA-Z])(?![a-zA-Z])', equation
        )
        for letter in latin_matches:
            if letter in found or letter in self._LATIN_SKIP:
                continue
            variables.append(VariableExplanation(
                symbol=letter,
                latex=letter,
                description=self.COMMON_VARIABLES.get(letter, "variable"),
            ))
            found.add(letter)

        return variables[:10]  # Limit to the first 10

    def _explain_local(
        self,
        equation: str,
        eq_type: str,
        readable: str,
        variables: List[VariableExplanation]
    ) -> EquationExplanation:
        """Generate an explanation from the built-in templates (no LLM).

        Unknown types fall back to the ``'expression'`` meaning and to
        ``DEFAULT_IMPORTANCE``. ``equation`` is accepted but not used.
        """
        meaning = self.MEANING_TEMPLATES.get(
            eq_type, self.MEANING_TEMPLATES['expression']
        )
        importance = self.IMPORTANCE_TEMPLATES.get(
            eq_type, self.DEFAULT_IMPORTANCE
        )
        return EquationExplanation(
            readable=readable,
            meaning=meaning,
            variables=variables,
            importance=importance,
            equation_type=eq_type
        )

    @staticmethod
    def _build_prompt(
        equation: str,
        context: Optional[str],
        eq_type: str,
        readable: str,
        variables: List[VariableExplanation]
    ) -> str:
        """Build the user prompt sent to the LLM.

        At most the first 500 characters of ``context`` are included; when
        there is no context that line is left empty.
        """
        var_list = ", ".join(f"{v.symbol} ({v.description})" for v in variables)
        context_line = f"Context: {context[:500]}" if context else ""
        return "\n".join([
            "Explain this mathematical equation from a research paper:",
            f"LaTeX: {equation}",
            f"Readable form: {readable}",
            f"Equation type: {eq_type}",
            f"Variables: {var_list}",
            context_line,
            "Please provide:",
            "1. A clear explanation of what this equation REPRESENTS (2-3 sentences)",
            "2. Why this equation MATTERS in the context of ML/research (1-2 sentences)",
            "Keep explanations concise and accurate. Format as:",
            "MEANING: [explanation]",
            "IMPORTANCE: [why it matters]",
        ])

    @staticmethod
    def _parse_llm_response(result: str) -> tuple:
        """Split an LLM reply into ``(meaning, importance)``.

        Either element is an empty string when its marker is absent.
        """
        meaning = ""
        importance = ""
        if "MEANING:" in result:
            meaning = result.split("MEANING:")[1].split("IMPORTANCE:")[0].strip()
        if "IMPORTANCE:" in result:
            importance = result.split("IMPORTANCE:")[1].strip()
        return meaning, importance

    async def _explain_with_llm(
        self,
        equation: str,
        context: Optional[str],
        eq_type: str,
        readable: str,
        variables: List[VariableExplanation]
    ) -> EquationExplanation:
        """Generate an explanation using the configured LLM endpoint.

        Posts a chat-completion request to ``{base_url}/chat/completions``
        and parses the ``MEANING:`` / ``IMPORTANCE:`` sections of the reply.
        Any failure (transport, HTTP status, unexpected payload) is logged
        and the template-based explanation is returned instead.
        """
        config = get_llm_config()
        prompt = self._build_prompt(equation, context, eq_type, readable, variables)

        try:
            async with httpx.AsyncClient(timeout=30.0) as client:
                headers = {
                    "Authorization": f"Bearer {config['api_key']}",
                    "Content-Type": "application/json"
                }
                if config['provider'] == 'openrouter':
                    headers["HTTP-Referer"] = "https://ai-research-helper.local"

                response = await client.post(
                    f"{config['base_url']}/chat/completions",
                    headers=headers,
                    json={
                        "model": config['model'],
                        "messages": [
                            {"role": "system", "content": "You are an expert ML researcher explaining equations. Be accurate and concise."},
                            {"role": "user", "content": prompt}
                        ],
                        "temperature": 0.2,
                        "max_tokens": 400
                    }
                )
                response.raise_for_status()
                result = response.json()['choices'][0]['message']['content']

            meaning, importance = self._parse_llm_response(result)
            return EquationExplanation(
                readable=readable,
                meaning=meaning or self._explain_local(equation, eq_type, readable, variables).meaning,
                variables=variables,
                importance=importance or "This equation is part of the paper's mathematical framework.",
                equation_type=eq_type
            )
        except Exception as e:
            # Deliberate best-effort boundary: degrade to the local
            # explanation rather than failing the request.
            logger.error("LLM equation explanation failed: %s", e)
            return self._explain_local(equation, eq_type, readable, variables)
# Lazily created, module-wide service instance.
_equation_service = None


def get_equation_service() -> EquationExplainerService:
    """Return the shared EquationExplainerService, creating it on first use."""
    global _equation_service
    if _equation_service is not None:
        return _equation_service
    _equation_service = EquationExplainerService()
    return _equation_service
|