Update app.py
Browse files
app.py
CHANGED
|
@@ -22,8 +22,10 @@ from langgraph.graph import StateGraph, START, END
|
|
| 22 |
# ============================================================
|
| 23 |
# HUGGING FACE SPACES READY
|
| 24 |
# Medical CSV RAG Chatbot
|
| 25 |
-
#
|
| 26 |
-
#
|
|
|
|
|
|
|
| 27 |
# ============================================================
|
| 28 |
|
| 29 |
# -------------------------------
|
|
@@ -41,11 +43,11 @@ logger = logging.getLogger(__name__)
|
|
| 41 |
# -------------------------------
|
| 42 |
@dataclass
|
| 43 |
class Config:
|
|
|
|
| 44 |
base_model_path: str = os.getenv(
|
| 45 |
"BASE_MODEL_PATH",
|
| 46 |
"meta-llama/Llama-3.1-8B-Instruct"
|
| 47 |
)
|
| 48 |
-
|
| 49 |
adapter_dir: str = os.getenv(
|
| 50 |
"ADAPTER_DIR",
|
| 51 |
"adapter_refined_v10"
|
|
@@ -60,32 +62,39 @@ class Config:
|
|
| 60 |
)
|
| 61 |
vectorstore_dir: str = field(init=False)
|
| 62 |
|
|
|
|
| 63 |
hf_token: str = os.getenv("HF_TOKEN", "")
|
| 64 |
deepseek_api_key: str = os.getenv("DEEPSEEK_API_KEY", "")
|
| 65 |
deepseek_base_url: str = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
|
| 66 |
deepseek_model: str = os.getenv("DEEPSEEK_MODEL", "deepseek-chat")
|
| 67 |
|
|
|
|
| 68 |
deepseek_temperature: float = float(os.getenv("DEEPSEEK_TEMPERATURE", "0.1"))
|
| 69 |
deepseek_max_tokens: int = int(os.getenv("DEEPSEEK_MAX_TOKENS", "550"))
|
| 70 |
|
|
|
|
| 71 |
embed_model_name: str = os.getenv(
|
| 72 |
"EMBED_MODEL_NAME",
|
| 73 |
"sentence-transformers/all-MiniLM-L6-v2"
|
| 74 |
)
|
| 75 |
|
|
|
|
| 76 |
similarity_k: int = int(os.getenv("SIMILARITY_K", "12"))
|
| 77 |
top_k_final: int = int(os.getenv("TOP_K_FINAL", "4"))
|
| 78 |
max_context_chars: int = int(os.getenv("MAX_CONTEXT_CHARS", "5200"))
|
| 79 |
|
|
|
|
| 80 |
max_input_len: int = int(os.getenv("MAX_INPUT_LEN", "4096"))
|
| 81 |
max_new_tokens_local: int = int(os.getenv("MAX_NEW_TOKENS_LOCAL", "180"))
|
| 82 |
max_chat_history_turns: int = int(os.getenv("MAX_CHAT_HISTORY_TURNS", "6"))
|
| 83 |
|
|
|
|
| 84 |
min_lexical_overlap: float = float(os.getenv("MIN_LEXICAL_OVERLAP", "0.08"))
|
| 85 |
min_faiss_similarity: float = float(os.getenv("MIN_FAISS_SIMILARITY", "0.20"))
|
| 86 |
strong_retrieval_threshold: float = float(os.getenv("STRONG_RETRIEVAL_THRESHOLD", "0.30"))
|
| 87 |
strong_retrieval_min_docs: int = int(os.getenv("STRONG_RETRIEVAL_MIN_DOCS", "3"))
|
| 88 |
|
|
|
|
| 89 |
use_query_cache: bool = os.getenv("USE_QUERY_CACHE", "true").lower() == "true"
|
| 90 |
enable_query_expansion: bool = os.getenv("ENABLE_QUERY_EXPANSION", "true").lower() == "true"
|
| 91 |
enable_validator: bool = os.getenv("ENABLE_VALIDATOR", "true").lower() == "true"
|
|
@@ -93,12 +102,15 @@ class Config:
|
|
| 93 |
show_debug_panel: bool = os.getenv("SHOW_DEBUG_PANEL", "true").lower() == "true"
|
| 94 |
allow_rebuild_vectorstore: bool = os.getenv("ALLOW_REBUILD_VECTORSTORE", "false").lower() == "true"
|
| 95 |
|
|
|
|
| 96 |
use_4bit: bool = os.getenv("USE_4BIT", "true").lower() == "true"
|
| 97 |
|
|
|
|
| 98 |
launch_debug: bool = os.getenv("LAUNCH_DEBUG", "false").lower() == "true"
|
| 99 |
server_name: str = os.getenv("SERVER_NAME", "0.0.0.0")
|
| 100 |
server_port: int = int(os.getenv("SERVER_PORT", "7860"))
|
| 101 |
|
|
|
|
| 102 |
blink_stage_1: float = float(os.getenv("BLINK_STAGE_1", "0.40"))
|
| 103 |
blink_stage_2: float = float(os.getenv("BLINK_STAGE_2", "0.55"))
|
| 104 |
blink_stage_3: float = float(os.getenv("BLINK_STAGE_3", "0.50"))
|
|
|
|
| 22 |
# ============================================================
|
| 23 |
# HUGGING FACE SPACES READY
|
| 24 |
# Medical CSV RAG Chatbot
|
| 25 |
+
# Optimized pipeline:
|
| 26 |
+
# RAG retrieval -> local ECG adapter reasoning -> grounded summary
|
| 27 |
+
# UI goal:
|
| 28 |
+
# polished mobile-friendly chatbot UX with minimal sources panel
|
| 29 |
# ============================================================
|
| 30 |
|
| 31 |
# -------------------------------
|
|
|
|
| 43 |
# -------------------------------
|
| 44 |
@dataclass
|
| 45 |
class Config:
|
| 46 |
+
# Paths
|
| 47 |
base_model_path: str = os.getenv(
|
| 48 |
"BASE_MODEL_PATH",
|
| 49 |
"meta-llama/Llama-3.1-8B-Instruct"
|
| 50 |
)
|
|
|
|
| 51 |
adapter_dir: str = os.getenv(
|
| 52 |
"ADAPTER_DIR",
|
| 53 |
"adapter_refined_v10"
|
|
|
|
| 62 |
)
|
| 63 |
vectorstore_dir: str = field(init=False)
|
| 64 |
|
| 65 |
+
# Auth / APIs
|
| 66 |
hf_token: str = os.getenv("HF_TOKEN", "")
|
| 67 |
deepseek_api_key: str = os.getenv("DEEPSEEK_API_KEY", "")
|
| 68 |
deepseek_base_url: str = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
|
| 69 |
deepseek_model: str = os.getenv("DEEPSEEK_MODEL", "deepseek-chat")
|
| 70 |
|
| 71 |
+
# DeepSeek generation
|
| 72 |
deepseek_temperature: float = float(os.getenv("DEEPSEEK_TEMPERATURE", "0.1"))
|
| 73 |
deepseek_max_tokens: int = int(os.getenv("DEEPSEEK_MAX_TOKENS", "550"))
|
| 74 |
|
| 75 |
+
# Embeddings
|
| 76 |
embed_model_name: str = os.getenv(
|
| 77 |
"EMBED_MODEL_NAME",
|
| 78 |
"sentence-transformers/all-MiniLM-L6-v2"
|
| 79 |
)
|
| 80 |
|
| 81 |
+
# Retrieval
|
| 82 |
similarity_k: int = int(os.getenv("SIMILARITY_K", "12"))
|
| 83 |
top_k_final: int = int(os.getenv("TOP_K_FINAL", "4"))
|
| 84 |
max_context_chars: int = int(os.getenv("MAX_CONTEXT_CHARS", "5200"))
|
| 85 |
|
| 86 |
+
# Generation
|
| 87 |
max_input_len: int = int(os.getenv("MAX_INPUT_LEN", "4096"))
|
| 88 |
max_new_tokens_local: int = int(os.getenv("MAX_NEW_TOKENS_LOCAL", "180"))
|
| 89 |
max_chat_history_turns: int = int(os.getenv("MAX_CHAT_HISTORY_TURNS", "6"))
|
| 90 |
|
| 91 |
+
# Filtering
|
| 92 |
min_lexical_overlap: float = float(os.getenv("MIN_LEXICAL_OVERLAP", "0.08"))
|
| 93 |
min_faiss_similarity: float = float(os.getenv("MIN_FAISS_SIMILARITY", "0.20"))
|
| 94 |
strong_retrieval_threshold: float = float(os.getenv("STRONG_RETRIEVAL_THRESHOLD", "0.30"))
|
| 95 |
strong_retrieval_min_docs: int = int(os.getenv("STRONG_RETRIEVAL_MIN_DOCS", "3"))
|
| 96 |
|
| 97 |
+
# Features
|
| 98 |
use_query_cache: bool = os.getenv("USE_QUERY_CACHE", "true").lower() == "true"
|
| 99 |
enable_query_expansion: bool = os.getenv("ENABLE_QUERY_EXPANSION", "true").lower() == "true"
|
| 100 |
enable_validator: bool = os.getenv("ENABLE_VALIDATOR", "true").lower() == "true"
|
|
|
|
| 102 |
show_debug_panel: bool = os.getenv("SHOW_DEBUG_PANEL", "true").lower() == "true"
|
| 103 |
allow_rebuild_vectorstore: bool = os.getenv("ALLOW_REBUILD_VECTORSTORE", "false").lower() == "true"
|
| 104 |
|
| 105 |
+
# Model loading
|
| 106 |
use_4bit: bool = os.getenv("USE_4BIT", "true").lower() == "true"
|
| 107 |
|
| 108 |
+
# Launch
|
| 109 |
launch_debug: bool = os.getenv("LAUNCH_DEBUG", "false").lower() == "true"
|
| 110 |
server_name: str = os.getenv("SERVER_NAME", "0.0.0.0")
|
| 111 |
server_port: int = int(os.getenv("SERVER_PORT", "7860"))
|
| 112 |
|
| 113 |
+
# UI timings
|
| 114 |
blink_stage_1: float = float(os.getenv("BLINK_STAGE_1", "0.40"))
|
| 115 |
blink_stage_2: float = float(os.getenv("BLINK_STAGE_2", "0.55"))
|
| 116 |
blink_stage_3: float = float(os.getenv("BLINK_STAGE_3", "0.50"))
|