|
|
""" |
|
|
Inference/Prediction Module |
|
|
Load fine-tuned model and extract entities from emails. |
|
|
""" |
|
|
|
|
|
import json |
|
|
import re |
|
|
from pathlib import Path |
|
|
from typing import Dict, Optional, Union |
|
|
from dataclasses import dataclass |
|
|
|
|
|
|
|
|
@dataclass
class PredictionResult:
    """Outcome of a single entity-extraction run.

    Bundles the parsed entities with the raw model text, a success
    flag, and an optional error description.
    """

    entities: Dict                 # parsed entity mapping ({} when parsing failed)
    raw_response: str              # unmodified text produced by the model
    success: bool                  # True iff a JSON object was parsed from the response
    error: Optional[str] = None    # human-readable failure reason, if any

    def to_dict(self) -> Dict:
        """Return all fields as a plain dictionary."""
        return dict(
            entities=self.entities,
            raw_response=self.raw_response,
            success=self.success,
            error=self.error,
        )

    def to_json(self, indent: int = 2) -> str:
        """Serialize only the extracted entities as indented JSON text."""
        return json.dumps(self.entities, indent=indent)
|
|
|
|
|
|
|
|
class Predictor: |
|
|
""" |
|
|
Load and run inference with fine-tuned model. |
|
|
|
|
|
Supports both: |
|
|
- Base model + LoRA adapter |
|
|
- Merged model |
|
|
""" |
|
|
|
|
|
EXTRACTION_PROMPT = """Extract financial entities from this email: |
|
|
|
|
|
Subject: {subject} |
|
|
|
|
|
Body: {body}""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
model_path: Union[str, Path], |
|
|
adapter_path: Optional[Union[str, Path]] = None, |
|
|
max_tokens: int = 200 |
|
|
): |
|
|
""" |
|
|
Initialize predictor. |
|
|
|
|
|
Args: |
|
|
model_path: Path to model (base or merged) |
|
|
adapter_path: Optional path to LoRA adapter |
|
|
max_tokens: Maximum tokens to generate |
|
|
""" |
|
|
self.model_path = Path(model_path) |
|
|
self.adapter_path = Path(adapter_path) if adapter_path else None |
|
|
self.max_tokens = max_tokens |
|
|
|
|
|
self.model = None |
|
|
self.tokenizer = None |
|
|
self._loaded = False |
|
|
|
|
|
def load(self): |
|
|
"""Load the model and tokenizer.""" |
|
|
if self._loaded: |
|
|
return |
|
|
|
|
|
try: |
|
|
from mlx_lm import load |
|
|
except ImportError: |
|
|
raise ImportError( |
|
|
"mlx_lm is required. Install with: pip install mlx-lm" |
|
|
) |
|
|
|
|
|
print(f"🔄 Loading model from {self.model_path}...") |
|
|
|
|
|
if self.adapter_path: |
|
|
print(f" With adapter: {self.adapter_path}") |
|
|
self.model, self.tokenizer = load( |
|
|
str(self.model_path), |
|
|
adapter_path=str(self.adapter_path) |
|
|
) |
|
|
else: |
|
|
self.model, self.tokenizer = load(str(self.model_path)) |
|
|
|
|
|
self._loaded = True |
|
|
print("✅ Model loaded successfully!") |
|
|
|
|
|
def predict( |
|
|
self, |
|
|
subject: str = "", |
|
|
body: str = "", |
|
|
email_text: Optional[str] = None |
|
|
) -> PredictionResult: |
|
|
""" |
|
|
Extract entities from an email. |
|
|
|
|
|
Args: |
|
|
subject: Email subject |
|
|
body: Email body |
|
|
email_text: Full email text (alternative to subject+body) |
|
|
|
|
|
Returns: |
|
|
PredictionResult with extracted entities |
|
|
""" |
|
|
if not self._loaded: |
|
|
self.load() |
|
|
|
|
|
try: |
|
|
from mlx_lm import generate |
|
|
except ImportError: |
|
|
raise ImportError("mlx_lm is required") |
|
|
|
|
|
|
|
|
if email_text: |
|
|
prompt = f"Extract financial entities from this email:\n\n{email_text}" |
|
|
else: |
|
|
prompt = self.EXTRACTION_PROMPT.format( |
|
|
subject=subject[:200], |
|
|
body=body[:1500] |
|
|
) |
|
|
|
|
|
|
|
|
try: |
|
|
response = generate( |
|
|
self.model, |
|
|
self.tokenizer, |
|
|
prompt=prompt, |
|
|
max_tokens=self.max_tokens, |
|
|
verbose=False |
|
|
) |
|
|
except Exception as e: |
|
|
return PredictionResult( |
|
|
entities={}, |
|
|
raw_response="", |
|
|
success=False, |
|
|
error=f"Generation failed: {str(e)}" |
|
|
) |
|
|
|
|
|
|
|
|
entities = self._extract_json(response) |
|
|
|
|
|
return PredictionResult( |
|
|
entities=entities if entities else {}, |
|
|
raw_response=response, |
|
|
success=entities is not None |
|
|
) |
|
|
|
|
|
def predict_batch( |
|
|
self, |
|
|
emails: list |
|
|
) -> list: |
|
|
""" |
|
|
Extract entities from multiple emails. |
|
|
|
|
|
Args: |
|
|
emails: List of dicts with 'subject' and 'body' keys |
|
|
|
|
|
Returns: |
|
|
List of PredictionResults |
|
|
""" |
|
|
results = [] |
|
|
for email in emails: |
|
|
result = self.predict( |
|
|
subject=email.get('subject', ''), |
|
|
body=email.get('body', '') |
|
|
) |
|
|
results.append(result) |
|
|
return results |
|
|
|
|
|
def _extract_json(self, response: str) -> Optional[Dict]: |
|
|
"""Extract JSON object from model response.""" |
|
|
|
|
|
match = re.search(r'\{[^{}]*\}', response) |
|
|
|
|
|
if match: |
|
|
try: |
|
|
return json.loads(match.group()) |
|
|
except json.JSONDecodeError: |
|
|
pass |
|
|
|
|
|
return None |
|
|
|
|
|
|
|
|
def main():
    """CLI for running predictions."""
    import argparse

    parser = argparse.ArgumentParser(
        description="Extract financial entities from emails using fine-tuned LLM"
    )
    parser.add_argument("--model", type=str,
                        default=str(Path.home() / "llm-mail-trainer/models/base/phi3-mini"),
                        help="Path to model")
    parser.add_argument("--adapter", type=str, default=None,
                        help="Path to LoRA adapter (optional)")
    parser.add_argument("--subject", type=str, default="",
                        help="Email subject")
    parser.add_argument("--body", type=str, default=None,
                        help="Email body text")
    parser.add_argument("--file", type=str, default=None,
                        help="Path to file containing email text")
    parser.add_argument("--interactive", action="store_true",
                        help="Run in interactive mode")
    args = parser.parse_args()

    predictor = Predictor(model_path=args.model, adapter_path=args.adapter)

    # Dispatch on whichever input mode the user selected.
    if args.interactive:
        run_interactive(predictor)
    elif args.file:
        text = Path(args.file).read_text()
        predictor.load()
        print(predictor.predict(email_text=text).to_json())
    elif args.body:
        predictor.load()
        print(predictor.predict(subject=args.subject, body=args.body).to_json())
    else:
        parser.print_help()
|
|
|
|
|
|
|
|
def run_interactive(predictor: Predictor):
    """Interactive console loop for testing the predictor by hand."""
    predictor.load()

    banner = "=" * 60
    print("\n" + banner)
    print("🧠 LLM Mail Trainer - Interactive Mode")
    print(banner)
    print("Enter email text to extract entities.")
    print("Type 'quit' or 'exit' to stop.")
    print(banner + "\n")

    while True:
        print("\n📧 Enter email text (multi-line, end with empty line):")

        # Collect lines until a blank line terminates the email,
        # or the user asks to quit / sends EOF.
        collected = []
        while True:
            try:
                entry = input()
            except EOFError:
                print("\n👋 Goodbye!")
                return
            if entry.lower() in ('quit', 'exit'):
                print("\n👋 Goodbye!")
                return
            if entry == "" and collected:
                break
            collected.append(entry)

        email_text = "\n".join(collected)
        if not email_text.strip():
            continue

        print("\n🔄 Extracting entities...")
        result = predictor.predict(email_text=email_text)

        print("\n📋 Extracted Entities:")
        print("-" * 40)
        print(result.to_json())

        if not result.success:
            print(f"\n⚠️ Warning: {result.error or 'Could not parse JSON from response'}")
            print(f"Raw response: {result.raw_response[:200]}...")
|
|
|
|
|
|
|
|
# Script entry point: run the CLI when executed directly.
if __name__ == "__main__":
    main()
|
|
|