resberry commited on
Commit
fa8ce00
·
verified ·
1 Parent(s): 41aa811

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -3
app.py CHANGED
@@ -22,8 +22,10 @@ from langgraph.graph import StateGraph, START, END
22
  # ============================================================
23
  # HUGGING FACE SPACES READY
24
  # Medical CSV RAG Chatbot
25
- # Mobile-friendly UI/UX version
26
- # Pipeline: RAG retrieval -> local ECG adapter reasoning -> grounded summary
 
 
27
  # ============================================================
28
 
29
  # -------------------------------
@@ -41,11 +43,11 @@ logger = logging.getLogger(__name__)
41
  # -------------------------------
42
  @dataclass
43
  class Config:
 
44
  base_model_path: str = os.getenv(
45
  "BASE_MODEL_PATH",
46
  "meta-llama/Llama-3.1-8B-Instruct"
47
  )
48
-
49
  adapter_dir: str = os.getenv(
50
  "ADAPTER_DIR",
51
  "adapter_refined_v10"
@@ -60,32 +62,39 @@ class Config:
60
  )
61
  vectorstore_dir: str = field(init=False)
62
 
 
63
  hf_token: str = os.getenv("HF_TOKEN", "")
64
  deepseek_api_key: str = os.getenv("DEEPSEEK_API_KEY", "")
65
  deepseek_base_url: str = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
66
  deepseek_model: str = os.getenv("DEEPSEEK_MODEL", "deepseek-chat")
67
 
 
68
  deepseek_temperature: float = float(os.getenv("DEEPSEEK_TEMPERATURE", "0.1"))
69
  deepseek_max_tokens: int = int(os.getenv("DEEPSEEK_MAX_TOKENS", "550"))
70
 
 
71
  embed_model_name: str = os.getenv(
72
  "EMBED_MODEL_NAME",
73
  "sentence-transformers/all-MiniLM-L6-v2"
74
  )
75
 
 
76
  similarity_k: int = int(os.getenv("SIMILARITY_K", "12"))
77
  top_k_final: int = int(os.getenv("TOP_K_FINAL", "4"))
78
  max_context_chars: int = int(os.getenv("MAX_CONTEXT_CHARS", "5200"))
79
 
 
80
  max_input_len: int = int(os.getenv("MAX_INPUT_LEN", "4096"))
81
  max_new_tokens_local: int = int(os.getenv("MAX_NEW_TOKENS_LOCAL", "180"))
82
  max_chat_history_turns: int = int(os.getenv("MAX_CHAT_HISTORY_TURNS", "6"))
83
 
 
84
  min_lexical_overlap: float = float(os.getenv("MIN_LEXICAL_OVERLAP", "0.08"))
85
  min_faiss_similarity: float = float(os.getenv("MIN_FAISS_SIMILARITY", "0.20"))
86
  strong_retrieval_threshold: float = float(os.getenv("STRONG_RETRIEVAL_THRESHOLD", "0.30"))
87
  strong_retrieval_min_docs: int = int(os.getenv("STRONG_RETRIEVAL_MIN_DOCS", "3"))
88
 
 
89
  use_query_cache: bool = os.getenv("USE_QUERY_CACHE", "true").lower() == "true"
90
  enable_query_expansion: bool = os.getenv("ENABLE_QUERY_EXPANSION", "true").lower() == "true"
91
  enable_validator: bool = os.getenv("ENABLE_VALIDATOR", "true").lower() == "true"
@@ -93,12 +102,15 @@ class Config:
93
  show_debug_panel: bool = os.getenv("SHOW_DEBUG_PANEL", "true").lower() == "true"
94
  allow_rebuild_vectorstore: bool = os.getenv("ALLOW_REBUILD_VECTORSTORE", "false").lower() == "true"
95
 
 
96
  use_4bit: bool = os.getenv("USE_4BIT", "true").lower() == "true"
97
 
 
98
  launch_debug: bool = os.getenv("LAUNCH_DEBUG", "false").lower() == "true"
99
  server_name: str = os.getenv("SERVER_NAME", "0.0.0.0")
100
  server_port: int = int(os.getenv("SERVER_PORT", "7860"))
101
 
 
102
  blink_stage_1: float = float(os.getenv("BLINK_STAGE_1", "0.40"))
103
  blink_stage_2: float = float(os.getenv("BLINK_STAGE_2", "0.55"))
104
  blink_stage_3: float = float(os.getenv("BLINK_STAGE_3", "0.50"))
 
22
  # ============================================================
23
  # HUGGING FACE SPACES READY
24
  # Medical CSV RAG Chatbot
25
+ # Optimized pipeline:
26
+ # RAG retrieval -> local ECG adapter reasoning -> grounded summary
27
+ # UI goal:
28
+ # polished mobile-friendly chatbot UX with minimal sources panel
29
  # ============================================================
30
 
31
  # -------------------------------
 
43
  # -------------------------------
44
  @dataclass
45
  class Config:
46
+ # Paths
47
  base_model_path: str = os.getenv(
48
  "BASE_MODEL_PATH",
49
  "meta-llama/Llama-3.1-8B-Instruct"
50
  )
 
51
  adapter_dir: str = os.getenv(
52
  "ADAPTER_DIR",
53
  "adapter_refined_v10"
 
62
  )
63
  vectorstore_dir: str = field(init=False)
64
 
65
+ # Auth / APIs
66
  hf_token: str = os.getenv("HF_TOKEN", "")
67
  deepseek_api_key: str = os.getenv("DEEPSEEK_API_KEY", "")
68
  deepseek_base_url: str = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
69
  deepseek_model: str = os.getenv("DEEPSEEK_MODEL", "deepseek-chat")
70
 
71
+ # DeepSeek generation
72
  deepseek_temperature: float = float(os.getenv("DEEPSEEK_TEMPERATURE", "0.1"))
73
  deepseek_max_tokens: int = int(os.getenv("DEEPSEEK_MAX_TOKENS", "550"))
74
 
75
+ # Embeddings
76
  embed_model_name: str = os.getenv(
77
  "EMBED_MODEL_NAME",
78
  "sentence-transformers/all-MiniLM-L6-v2"
79
  )
80
 
81
+ # Retrieval
82
  similarity_k: int = int(os.getenv("SIMILARITY_K", "12"))
83
  top_k_final: int = int(os.getenv("TOP_K_FINAL", "4"))
84
  max_context_chars: int = int(os.getenv("MAX_CONTEXT_CHARS", "5200"))
85
 
86
+ # Generation
87
  max_input_len: int = int(os.getenv("MAX_INPUT_LEN", "4096"))
88
  max_new_tokens_local: int = int(os.getenv("MAX_NEW_TOKENS_LOCAL", "180"))
89
  max_chat_history_turns: int = int(os.getenv("MAX_CHAT_HISTORY_TURNS", "6"))
90
 
91
+ # Filtering
92
  min_lexical_overlap: float = float(os.getenv("MIN_LEXICAL_OVERLAP", "0.08"))
93
  min_faiss_similarity: float = float(os.getenv("MIN_FAISS_SIMILARITY", "0.20"))
94
  strong_retrieval_threshold: float = float(os.getenv("STRONG_RETRIEVAL_THRESHOLD", "0.30"))
95
  strong_retrieval_min_docs: int = int(os.getenv("STRONG_RETRIEVAL_MIN_DOCS", "3"))
96
 
97
+ # Features
98
  use_query_cache: bool = os.getenv("USE_QUERY_CACHE", "true").lower() == "true"
99
  enable_query_expansion: bool = os.getenv("ENABLE_QUERY_EXPANSION", "true").lower() == "true"
100
  enable_validator: bool = os.getenv("ENABLE_VALIDATOR", "true").lower() == "true"
 
102
  show_debug_panel: bool = os.getenv("SHOW_DEBUG_PANEL", "true").lower() == "true"
103
  allow_rebuild_vectorstore: bool = os.getenv("ALLOW_REBUILD_VECTORSTORE", "false").lower() == "true"
104
 
105
+ # Model loading
106
  use_4bit: bool = os.getenv("USE_4BIT", "true").lower() == "true"
107
 
108
+ # Launch
109
  launch_debug: bool = os.getenv("LAUNCH_DEBUG", "false").lower() == "true"
110
  server_name: str = os.getenv("SERVER_NAME", "0.0.0.0")
111
  server_port: int = int(os.getenv("SERVER_PORT", "7860"))
112
 
113
+ # UI timings
114
  blink_stage_1: float = float(os.getenv("BLINK_STAGE_1", "0.40"))
115
  blink_stage_2: float = float(os.getenv("BLINK_STAGE_2", "0.55"))
116
  blink_stage_3: float = float(os.getenv("BLINK_STAGE_3", "0.50"))