Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -97,18 +97,39 @@ class GPTResponse(BaseModel):
|
|
| 97 |
usage: dict = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
| 98 |
|
| 99 |
# ------------------- 5. 加载 Cross-Encoder 模型(全局唯一实例) -------------------
|
|
|
|
| 100 |
class CrossEncoderModel:
|
| 101 |
def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
|
| 102 |
self.model_name = model_name
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 108 |
self.model.to(self.device)
|
| 109 |
-
self.model.eval()
|
| 110 |
print(f"模型加载完成!使用设备:{self.device}")
|
| 111 |
|
|
|
|
|
|
|
| 112 |
def rerank(self, query: str, documents: List[str], top_k: int, truncation: bool) -> List[DocumentScore]:
|
| 113 |
"""核心重排序逻辑:计算查询与文档的相关性并排序"""
|
| 114 |
# 参数校验
|
|
|
|
| 97 |
usage: dict = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
|
| 98 |
|
| 99 |
# ------------------- 5. 加载 Cross-Encoder 模型(全局唯一实例) -------------------
|
| 100 |
+
# CrossEncoderModel.__init__ verifies that the model cache directory is writable before loading the model.
|
| 101 |
class CrossEncoderModel:
|
| 102 |
def __init__(self, model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"):
    """Load the cross-encoder tokenizer and model into a writable cache.

    The cache directory is taken from the ``TRANSFORMERS_CACHE`` environment
    variable and falls back to ``/tmp/huggingface_cache``. It is created if
    missing and probed for write access before anything is downloaded, so a
    permission problem surfaces as one clear error.

    Args:
        model_name: Model identifier passed to ``from_pretrained`` for both
            the tokenizer and the sequence-classification model.

    Raises:
        RuntimeError: If the cache directory cannot be created or written to.
    """
    # Local import: only this probe needs it, and the file's import block
    # is outside this method.
    import tempfile

    self.model_name = model_name

    cache_dir = os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface_cache")
    try:
        # Create the directory first; otherwise a missing directory would be
        # misreported as a permission failure.
        os.makedirs(cache_dir, exist_ok=True)
        # A uniquely named probe file avoids a race between concurrently
        # starting workers that would otherwise share (and delete) one
        # fixed file name. The file is removed when the context exits.
        with tempfile.NamedTemporaryFile(
            mode="w",
            dir=cache_dir,
            prefix="test_write_permission_",
            suffix=".txt",
        ) as probe:
            probe.write("test")
    except OSError as e:
        raise RuntimeError(f"缓存目录不可写,请检查权限:{cache_dir},错误:{e}") from e
    print(f"缓存目录权限验证通过:{cache_dir}")

    # Pass the cache directory explicitly so both downloads land in the
    # location that was just verified.
    self.tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        cache_dir=cache_dir,
    )
    self.model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        cache_dir=cache_dir,
    )

    self.device = "cuda" if torch.cuda.is_available() else "cpu"
    self.model.to(self.device)
    self.model.eval()
    print(f"模型加载完成!使用设备:{self.device}")
|
| 130 |
|
| 131 |
+
|
| 132 |
+
|
| 133 |
def rerank(self, query: str, documents: List[str], top_k: int, truncation: bool) -> List[DocumentScore]:
|
| 134 |
"""核心重排序逻辑:计算查询与文档的相关性并排序"""
|
| 135 |
# 参数校验
|