""" |
|
|
Download and test DeepSeek-OCR model from Hugging Face. |
|
|
This script downloads the model, verifies installation, and runs a simple test. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import torch |
|
|
from pathlib import Path |
|
|
import argparse |
|
|
from typing import Optional, Tuple |
|
|
import json |
|
|
import time |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
|
|
|
|
|
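
# Typical invocations, using the flags defined in main() below (the script
# filename is illustrative; adjust it to where this file lives in your checkout):
#
#   python download_deepseek_ocr.py                 # download + smoke test
#   python download_deepseek_ocr.py --skip-test     # download only
#   python download_deepseek_ocr.py --test-only     # test an existing download
#   python download_deepseek_ocr.py --force         # re-download from scratch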


def check_dependencies() -> bool:
    """Check that all required dependencies are installed."""
    missing_deps = []

    try:
        import transformers
        print(f"✓ Transformers version: {transformers.__version__}")
    except ImportError:
        missing_deps.append("transformers")

    try:
        import torch
        print(f"✓ PyTorch version: {torch.__version__}")

        if torch.cuda.is_available():
            print(f"✓ CUDA available: {torch.cuda.get_device_name(0)}")
            print(f"  CUDA version: {torch.version.cuda}")
        else:
            print("⚠ CUDA not available - will use CPU (slower)")
    except ImportError:
        missing_deps.append("torch")

    try:
        import einops  # noqa: F401
        print("✓ Einops installed")
    except ImportError:
        missing_deps.append("einops")

    try:
        import peft
        print(f"✓ PEFT version: {peft.__version__}")
    except ImportError:
        missing_deps.append("peft")

    try:
        import flash_attn  # noqa: F401
        print("✓ Flash Attention installed")
    except ImportError:
        print("⚠ Flash Attention not installed (optional but recommended)")
        print("  Install with: pip install flash-attn --no-build-isolation")

    if missing_deps:
        print(f"\n✗ Missing dependencies: {', '.join(missing_deps)}")
        print("Please install them with: pip install -r requirements.txt")
        return False

    return True


def download_deepseek_ocr(
    model_name: str = "deepseek-ai/DeepSeek-OCR",
    cache_dir: Optional[str] = None,
    force_download: bool = False,
) -> Tuple[bool, str]:
    """
    Download the DeepSeek-OCR model from Hugging Face.

    Args:
        model_name: Model identifier on Hugging Face
        cache_dir: Directory to store the model files
        force_download: Force a re-download even if files are cached

    Returns:
        Tuple of (success, message)
    """
    try:
        from transformers import AutoTokenizer
        from huggingface_hub import snapshot_download

        if cache_dir is None:
            cache_dir = "./models/deepseek_ocr"

        cache_path = Path(cache_dir)
        cache_path.mkdir(parents=True, exist_ok=True)

        print(f"\n{'='*60}")
        print("Downloading DeepSeek-OCR Model")
        print(f"{'='*60}")
        print(f"Model: {model_name}")
        print(f"Cache directory: {cache_path.absolute()}")
        print(f"Force download: {force_download}")
        print()

        # Skip the download when the weights are already on disk.
        model_files_exist = (cache_path / "model.safetensors").exists() or \
                            (cache_path / "pytorch_model.bin").exists()

        if model_files_exist and not force_download:
            print("✓ Model files already exist. Use --force to re-download.")
            return True, "Model already downloaded"

        print("Downloading model files...")
        start_time = time.time()

        # Track which repository actually served the files so the tokenizer
        # check and the metadata below reference the right source.
        resolved_model = model_name

        try:
            # local_dir places the snapshot directly in cache_dir, which is
            # the flat layout the existence check above and
            # verify_model_files() below expect. (Interrupted downloads
            # resume automatically in recent huggingface_hub versions.)
            local_dir = snapshot_download(
                repo_id=model_name,
                local_dir=cache_dir,
                force_download=force_download,
            )
            print(f"✓ Model downloaded to: {local_dir}")
        except Exception as e:
            print(f"✗ Failed to download from {model_name}: {e}")
            print("Trying alternative sources...")

            alt_models = [
                "unsloth/DeepSeek-OCR",
                "doublemathew/DeepSeek-OCR",
            ]

            for alt_model in alt_models:
                try:
                    print(f"  Trying {alt_model}...")
                    local_dir = snapshot_download(
                        repo_id=alt_model,
                        local_dir=cache_dir,
                        force_download=force_download,
                    )
                    print(f"✓ Model downloaded from {alt_model}")
                    resolved_model = alt_model
                    break
                except Exception as alt_e:
                    print(f"  ✗ Failed: {alt_e}")
                    continue
            else:
                return False, "Failed to download model from any source"

        # Confirm the tokenizer loads from the downloaded snapshot.
        print("\nLoading tokenizer from the downloaded snapshot...")
        AutoTokenizer.from_pretrained(
            str(cache_path),
            trust_remote_code=True,
        )
        print("✓ Tokenizer loaded")

        # Record download metadata alongside the model files.
        config = {
            "model_name": resolved_model,
            "cache_dir": str(cache_path.absolute()),
            "download_time": time.time() - start_time,
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        }

        config_file = cache_path / "download_config.json"
        with open(config_file, "w") as f:
            json.dump(config, f, indent=2)

        elapsed_time = time.time() - start_time
        print(f"\n✓ Download completed in {elapsed_time:.1f} seconds")

        return True, "Model downloaded successfully"

    except Exception as e:
        return False, f"Error downloading model: {e}"
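
# A minimal sketch of programmatic use, assuming the default cache layout
# above (main() below does the same thing with CLI flags):
#
#   ok, msg = download_deepseek_ocr(cache_dir="./models/deepseek_ocr")
#   if ok and verify_model_files("./models/deepseek_ocr"):
#       test_deepseek_ocr("./models/deepseek_ocr")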


def test_deepseek_ocr(cache_dir: Optional[str] = None) -> bool:
    """
    Test the DeepSeek-OCR model with a simple example.

    Args:
        cache_dir: Directory where the model is cached

    Returns:
        True if the test succeeded
    """
    try:
        from transformers import AutoModel, AutoTokenizer
        import torch

        if cache_dir is None:
            cache_dir = "./models/deepseek_ocr"

        print(f"\n{'='*60}")
        print("Testing DeepSeek-OCR Model")
        print(f"{'='*60}")

        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Device: {device}")

        # Load the model from the local snapshot directory.
        print("\nLoading model...")
        model = AutoModel.from_pretrained(
            cache_dir,
            trust_remote_code=True,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        ).to(device)
        print("✓ Model loaded")

        tokenizer = AutoTokenizer.from_pretrained(
            cache_dir,
            trust_remote_code=True,
        )
        print("✓ Tokenizer loaded")

        # Render a small synthetic document to exercise the pipeline.
        print("\nCreating test image...")
        test_image = create_test_image()
        test_image_path = Path("test_ocr_image.png")
        test_image.save(test_image_path)
        print(f"✓ Test image saved to {test_image_path}")

        print("\nRunning OCR on test image...")
        print("⚠ Note: Full inference requires the proper image preprocessing pipeline.")
        print("  This test confirms model loading but not full OCR functionality.")
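
        # A full end-to-end OCR call is deliberately not executed here. As a
        # sketch only: the DeepSeek-OCR model card (at the time of writing)
        # drives inference through a custom `infer` method exposed via
        # trust_remote_code. Verify the signature against the current model
        # card before uncommenting:
        #
        #   prompt = "<image>\n<|grounding|>Convert the document to markdown. "
        #   result = model.infer(
        #       tokenizer,
        #       prompt=prompt,
        #       image_file=str(test_image_path),
        #       output_path="./ocr_output",
        #       base_size=1024,
        #       image_size=640,
        #       crop_mode=True,
        #       save_results=True,
        #   )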

        # Remove the temporary test image.
        test_image_path.unlink()

        print("\n✓ Model test completed successfully!")
        print("  The model is ready for training and inference.")

        return True

    except Exception as e:
        print(f"\n✗ Test failed: {e}")
        return False


def create_test_image() -> Image.Image:
    """Create a simple test image with text for OCR testing."""
    from PIL import ImageDraw, ImageFont

    # White canvas roughly the size of a small form snippet.
    width, height = 400, 200
    image = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(image)

    # Header text resembling a medical record.
    text = "TEST OCR\nPatient: John Doe\nMRN: 12345\nDate: 2024-01-15"

    # Fall back to PIL's built-in font if the system font is unavailable.
    try:
        font = ImageFont.truetype(
            "/usr/share/fonts/truetype/liberation/LiberationSans-Regular.ttf", 20
        )
    except OSError:
        font = ImageFont.load_default()

    draw.multiline_text((20, 20), text, fill="black", font=font)

    # Draw a simple 2x2 results table.
    draw.rectangle((20, 100, 380, 180), outline="black", width=2)
    draw.line((200, 100, 200, 180), fill="black", width=2)
    draw.line((20, 130, 380, 130), fill="black", width=2)

    draw.text((30, 105), "Test Name", fill="black", font=font)
    draw.text((210, 105), "Result", fill="black", font=font)
    draw.text((30, 135), "Glucose", fill="black", font=font)
    draw.text((210, 135), "95 mg/dL", fill="black", font=font)

    return image


def verify_model_files(cache_dir: str) -> bool:
    """Verify that all necessary model files are present."""
    cache_path = Path(cache_dir)

    required_files = [
        "config.json",
        "tokenizer_config.json",
    ]

    # Weights may ship in either format; finding one of these is enough.
    model_files = [
        "model.safetensors",
        "pytorch_model.bin",
    ]

    print("\nVerifying model files...")

    missing_files = []
    for file in required_files:
        if not (cache_path / file).exists():
            missing_files.append(file)
            print(f"  ✗ {file} - Missing")
        else:
            print(f"  ✓ {file} - Found")

    model_found = False
    for file in model_files:
        if (cache_path / file).exists():
            print(f"  ✓ {file} - Found")
            model_found = True
            break

    if not model_found:
        print("  ✗ No model weights file found")
        missing_files.append("model weights")

    if missing_files:
        print(f"\n✗ Missing files: {', '.join(missing_files)}")
        return False

    print("\n✓ All required files present")
    return True


def main():
    """Main function to download and test DeepSeek-OCR."""
    parser = argparse.ArgumentParser(
        description="Download and test the DeepSeek-OCR model"
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default="deepseek-ai/DeepSeek-OCR",
        help="Model name on Hugging Face",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default="./models/deepseek_ocr",
        help="Directory to cache the model",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Force re-download even if the model exists",
    )
    parser.add_argument(
        "--skip-test",
        action="store_true",
        help="Skip the model test after download",
    )
    parser.add_argument(
        "--test-only",
        action="store_true",
        help="Only run the test, skip the download",
    )

    args = parser.parse_args()

    print("="*60)
    print("DeepSeek-OCR Model Setup")
    print("="*60)

    if not check_dependencies():
        print("\n✗ Please install the missing dependencies first")
        sys.exit(1)

    if args.test_only:
        if verify_model_files(args.cache_dir):
            success = test_deepseek_ocr(args.cache_dir)
            sys.exit(0 if success else 1)
        else:
            print("\n✗ Model files not found. Please download first.")
            sys.exit(1)

    success, message = download_deepseek_ocr(
        model_name=args.model_name,
        cache_dir=args.cache_dir,
        force_download=args.force,
    )

    if not success:
        print(f"\n✗ Download failed: {message}")
        sys.exit(1)

    print(f"\n✓ {message}")

    if not verify_model_files(args.cache_dir):
        print("\n✗ Model verification failed")
        sys.exit(1)

    if not args.skip_test:
        if not test_deepseek_ocr(args.cache_dir):
            # The download itself succeeded, so exit cleanly even though the
            # smoke test failed.
            print("\n⚠ Model test failed, but the download was successful")
            print("  You may need to install additional dependencies")
            sys.exit(0)

    print("\n" + "="*60)
    print("✓ DeepSeek-OCR setup complete!")
    print("="*60)
    print("\nNext steps:")
    print("1. Generate synthetic data: ./scripts/generate_synthea_data.sh")
    print("2. Convert to PDFs: python src/data_generation/synthea_to_pdf.py")
    print("3. Train LoRA adapter: python src/training/train_lora.py")
    print("="*60)


if __name__ == "__main__":
    main()