samwell Claude commited on
Commit
55abc5d
·
1 Parent(s): 4df7c2f

Fix remaining tool loading issues

Browse files

- Remove BitsAndBytesConfig import from NV-Reason-CXR (causing dict.to_dict() error)
- Remove load_in_4bit parameter from Classification and Report Generation tools
- Update timm to 0.9.16 to fix ImageNetInfo import error for VQA tool
- All tools should now load properly

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (3) hide show
  1. app.py +2 -4
  2. medrax/tools/nv_reason_cxr.py +1 -1
  3. requirements.txt +1 -1
app.py CHANGED
@@ -72,8 +72,7 @@ if device == "cuda":
72
  try:
73
  from medrax.tools.classification import TorchXRayVisionClassifierTool
74
  classification_tool = TorchXRayVisionClassifierTool(
75
- device=device,
76
- load_in_4bit=True
77
  )
78
  tools.append(classification_tool)
79
  print("✓ Loaded classification tool")
@@ -83,8 +82,7 @@ if device == "cuda":
83
  try:
84
  from medrax.tools.report_generation import ChestXRayReportGeneratorTool
85
  report_tool = ChestXRayReportGeneratorTool(
86
- device=device,
87
- load_in_4bit=True
88
  )
89
  tools.append(report_tool)
90
  print("✓ Loaded report generation tool")
 
72
  try:
73
  from medrax.tools.classification import TorchXRayVisionClassifierTool
74
  classification_tool = TorchXRayVisionClassifierTool(
75
+ device=device
 
76
  )
77
  tools.append(classification_tool)
78
  print("✓ Loaded classification tool")
 
82
  try:
83
  from medrax.tools.report_generation import ChestXRayReportGeneratorTool
84
  report_tool = ChestXRayReportGeneratorTool(
85
+ device=device
 
86
  )
87
  tools.append(report_tool)
88
  print("✓ Loaded report generation tool")
medrax/tools/nv_reason_cxr.py CHANGED
@@ -4,7 +4,7 @@ from pathlib import Path
4
  import torch
5
  from PIL import Image
6
  from pydantic import BaseModel, Field
7
- from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig
8
 
9
  from langchain_core.callbacks import (
10
  AsyncCallbackManagerForToolRun,
 
4
  import torch
5
  from PIL import Image
6
  from pydantic import BaseModel, Field
7
+ from transformers import AutoProcessor, AutoModelForImageTextToText
8
 
9
  from langchain_core.callbacks import (
10
  AsyncCallbackManagerForToolRun,
requirements.txt CHANGED
@@ -41,7 +41,7 @@ fastapi>=0.68.0
41
  python-multipart>=0.0.6
42
  einops>=0.3.0
43
  einops-exts>=0.0.4
44
- timm==0.5.4
45
  tiktoken>=0.3.0
46
  openai>=0.27.0
47
  backoff>=1.10.0
 
41
  python-multipart>=0.0.6
42
  einops>=0.3.0
43
  einops-exts>=0.0.4
44
+ timm==0.9.16
45
  tiktoken>=0.3.0
46
  openai>=0.27.0
47
  backoff>=1.10.0