Spaces:
Runtime error
Runtime error
Commit Β·
c8efeea
1
Parent(s): 0a42205
.....
Browse files- core/config.py +1 -1
- services/ml_service.py +14 -22
core/config.py
CHANGED
|
@@ -54,7 +54,7 @@ class Settings(BaseSettings):
|
|
| 54 |
ENABLE_AUTO_LANGUAGE_DETECTION: bool = True
|
| 55 |
|
| 56 |
# Safety & Compliance
|
| 57 |
-
ENABLE_RED_FLAG_DETECTION: bool = True
|
| 58 |
REQUIRE_DISCLAIMER: bool = True
|
| 59 |
LOG_ANONYMIZATION: bool = True
|
| 60 |
|
|
|
|
| 54 |
ENABLE_AUTO_LANGUAGE_DETECTION: bool = True
|
| 55 |
|
| 56 |
# Safety & Compliance
|
| 57 |
+
# ENABLE_RED_FLAG_DETECTION: bool = True
|
| 58 |
REQUIRE_DISCLAIMER: bool = True
|
| 59 |
LOG_ANONYMIZATION: bool = True
|
| 60 |
|
services/ml_service.py
CHANGED
|
@@ -6,18 +6,16 @@ from services.natlas_service import NATLaSService
|
|
| 6 |
|
| 7 |
class MLService:
|
| 8 |
"""ML Service with N-ATLaS and embeddings"""
|
| 9 |
-
|
| 10 |
def __init__(self):
|
| 11 |
self.embedding_model = None
|
| 12 |
self.natlas_service = None
|
| 13 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 14 |
-
|
| 15 |
-
self.use_natlas_flag = True
|
| 16 |
-
|
| 17 |
async def initialize(self):
|
| 18 |
"""Initialize all ML services"""
|
| 19 |
print(f"π€ Initializing ML Services on {self.device}")
|
| 20 |
-
|
| 21 |
# Load embedding model
|
| 22 |
print(f"π Loading: {settings.EMBEDDING_MODEL}")
|
| 23 |
self.embedding_model = SentenceTransformer(
|
|
@@ -25,28 +23,28 @@ class MLService:
|
|
| 25 |
device=self.device
|
| 26 |
)
|
| 27 |
print("β
Embedding model loaded")
|
| 28 |
-
|
| 29 |
# Initialize N-ATLaS
|
| 30 |
self.natlas_service = NATLaSService()
|
| 31 |
await self.natlas_service.initialize()
|
| 32 |
-
|
| 33 |
async def generate_embedding(self, text: str, language: Optional[str] = None) -> List[float]:
|
| 34 |
"""Generate embedding"""
|
| 35 |
if self.embedding_model is None:
|
| 36 |
raise RuntimeError("Embedding model not initialized")
|
| 37 |
-
|
| 38 |
embedding = self.embedding_model.encode(
|
| 39 |
text.strip(),
|
| 40 |
convert_to_numpy=True,
|
| 41 |
normalize_embeddings=True
|
| 42 |
)
|
| 43 |
return embedding.tolist()
|
| 44 |
-
|
| 45 |
async def generate_embeddings_batch(self, texts: List[str]) -> List[List[float]]:
|
| 46 |
"""Generate embeddings in batch"""
|
| 47 |
if self.embedding_model is None:
|
| 48 |
raise RuntimeError("Embedding model not initialized")
|
| 49 |
-
|
| 50 |
embeddings = self.embedding_model.encode(
|
| 51 |
[t.strip() for t in texts],
|
| 52 |
convert_to_numpy=True,
|
|
@@ -54,26 +52,20 @@ class MLService:
|
|
| 54 |
normalize_embeddings=True
|
| 55 |
)
|
| 56 |
return embeddings.tolist()
|
| 57 |
-
|
| 58 |
async def analyze_with_natlas(self, symptoms: str, language: str = "en") -> str:
|
| 59 |
"""Use N-ATLaS for analysis"""
|
| 60 |
-
if not self.use_natlas_flag:
|
| 61 |
-
return "N-ATLaS analysis disabled."
|
| 62 |
-
if self.natlas_service is None:
|
| 63 |
-
raise RuntimeError("N-ATLaS service not initialized")
|
| 64 |
return await self.natlas_service.analyze_symptoms(symptoms, language)
|
| 65 |
-
|
| 66 |
def detect_language(self, text: str) -> str:
|
| 67 |
"""Detect language"""
|
| 68 |
-
if self.natlas_service is None:
|
| 69 |
-
raise RuntimeError("N-ATLaS service not initialized")
|
| 70 |
return self.natlas_service.detect_language(text)
|
| 71 |
-
|
| 72 |
def get_model_info(self) -> dict:
|
| 73 |
"""Get model information"""
|
| 74 |
return {
|
| 75 |
"embedding_model": settings.EMBEDDING_MODEL,
|
| 76 |
"device": self.device,
|
| 77 |
-
"dimension": self.embedding_model.get_sentence_embedding_dimension()
|
| 78 |
-
"natlas": self.natlas_service.get_model_info()
|
| 79 |
-
}
|
|
|
|
| 6 |
|
| 7 |
class MLService:
|
| 8 |
"""ML Service with N-ATLaS and embeddings"""
|
| 9 |
+
|
| 10 |
def __init__(self):
|
| 11 |
self.embedding_model = None
|
| 12 |
self.natlas_service = None
|
| 13 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 14 |
+
|
|
|
|
|
|
|
| 15 |
async def initialize(self):
|
| 16 |
"""Initialize all ML services"""
|
| 17 |
print(f"π€ Initializing ML Services on {self.device}")
|
| 18 |
+
|
| 19 |
# Load embedding model
|
| 20 |
print(f"π Loading: {settings.EMBEDDING_MODEL}")
|
| 21 |
self.embedding_model = SentenceTransformer(
|
|
|
|
| 23 |
device=self.device
|
| 24 |
)
|
| 25 |
print("β
Embedding model loaded")
|
| 26 |
+
|
| 27 |
# Initialize N-ATLaS
|
| 28 |
self.natlas_service = NATLaSService()
|
| 29 |
await self.natlas_service.initialize()
|
| 30 |
+
|
| 31 |
async def generate_embedding(self, text: str, language: Optional[str] = None) -> List[float]:
|
| 32 |
"""Generate embedding"""
|
| 33 |
if self.embedding_model is None:
|
| 34 |
raise RuntimeError("Embedding model not initialized")
|
| 35 |
+
|
| 36 |
embedding = self.embedding_model.encode(
|
| 37 |
text.strip(),
|
| 38 |
convert_to_numpy=True,
|
| 39 |
normalize_embeddings=True
|
| 40 |
)
|
| 41 |
return embedding.tolist()
|
| 42 |
+
|
| 43 |
async def generate_embeddings_batch(self, texts: List[str]) -> List[List[float]]:
|
| 44 |
"""Generate embeddings in batch"""
|
| 45 |
if self.embedding_model is None:
|
| 46 |
raise RuntimeError("Embedding model not initialized")
|
| 47 |
+
|
| 48 |
embeddings = self.embedding_model.encode(
|
| 49 |
[t.strip() for t in texts],
|
| 50 |
convert_to_numpy=True,
|
|
|
|
| 52 |
normalize_embeddings=True
|
| 53 |
)
|
| 54 |
return embeddings.tolist()
|
| 55 |
+
|
| 56 |
async def analyze_with_natlas(self, symptoms: str, language: str = "en") -> str:
|
| 57 |
"""Use N-ATLaS for analysis"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
return await self.natlas_service.analyze_symptoms(symptoms, language)
|
| 59 |
+
|
| 60 |
def detect_language(self, text: str) -> str:
|
| 61 |
"""Detect language"""
|
|
|
|
|
|
|
| 62 |
return self.natlas_service.detect_language(text)
|
| 63 |
+
|
| 64 |
def get_model_info(self) -> dict:
|
| 65 |
"""Get model information"""
|
| 66 |
return {
|
| 67 |
"embedding_model": settings.EMBEDDING_MODEL,
|
| 68 |
"device": self.device,
|
| 69 |
+
"dimension": self.embedding_model.get_sentence_embedding_dimension(),
|
| 70 |
+
"natlas": self.natlas_service.get_model_info()
|
| 71 |
+
}
|