fiewolf1000 committed on
Commit
c025244
·
verified ·
1 Parent(s): d8156e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -5
app.py CHANGED
@@ -97,18 +97,39 @@ class GPTResponse(BaseModel):
97
  usage: dict = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
98
 
99
  # ------------------- 5. 加载 Cross-Encoder 模型(全局唯一实例) -------------------
 
100
  class CrossEncoderModel:
101
  def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
102
  self.model_name = model_name
103
- # 加载分词器和模型(从缓存目录加载,避免权限问题)
104
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
105
- self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
106
- # 自动选择设备(GPU 优先,无则用 CPU)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
108
  self.model.to(self.device)
109
- self.model.eval() # 推理模式,关闭 Dropout
110
  print(f"模型加载完成!使用设备:{self.device}")
111
 
 
 
112
  def rerank(self, query: str, documents: List[str], top_k: int, truncation: bool) -> List[DocumentScore]:
113
  """核心重排序逻辑:计算查询与文档的相关性并排序"""
114
  # 参数校验
 
97
  usage: dict = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
98
 
99
  # ------------------- 5. 加载 Cross-Encoder 模型(全局唯一实例) -------------------
100
+ # 在 CrossEncoderModel 类的 __init__ 方法前添加缓存目录验证
101
  class CrossEncoderModel:
102
  def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
103
  self.model_name = model_name
104
+
105
+ # 【新增】验证缓存目录是否可写
106
+ cache_dir = os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface_cache")
107
+ try:
108
+ # 尝试在缓存目录创建测试文件,验证权限
109
+ test_file = os.path.join(cache_dir, "test_write_permission.txt")
110
+ with open(test_file, "w") as f:
111
+ f.write("test")
112
+ os.remove(test_file) # 验证后删除测试文件
113
+ print(f"缓存目录权限验证通过:{cache_dir}")
114
+ except Exception as e:
115
+ raise RuntimeError(f"缓存目录不可写,请检查权限:{cache_dir},错误:{str(e)}")
116
+
117
+ # 加载模型(确保使用指定的缓存目录)
118
+ self.tokenizer = AutoTokenizer.from_pretrained(
119
+ model_name,
120
+ cache_dir=cache_dir # 显式指定缓存目录
121
+ )
122
+ self.model = AutoModelForSequenceClassification.from_pretrained(
123
+ model_name,
124
+ cache_dir=cache_dir # 显式指定缓存目录
125
+ )
126
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
127
  self.model.to(self.device)
128
+ self.model.eval()
129
  print(f"模型加载完成!使用设备:{self.device}")
130
 
131
+
132
+
133
  def rerank(self, query: str, documents: List[str], top_k: int, truncation: bool) -> List[DocumentScore]:
134
  """核心重排序逻辑:计算查询与文档的相关性并排序"""
135
  # 参数校验