mlbench123 committed on
Commit
a89f340
·
verified ·
1 Parent(s): 70950ed

Upload 6 files

Browse files
Files changed (6) hide show
  1. app.py +30 -0
  2. database.xlsx +0 -0
  3. llm_client.py +198 -0
  4. requirements.txt +12 -0
  5. treatment_embeddings.pkl +3 -0
  6. web_retriever.py +223 -0
app.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
"""Hugging Face Spaces entrypoint.

Spaces will pick up either an `app.py` exposing a `demo`/`app` variable or a
launched Gradio `Blocks`; this module exposes `demo`, built by the existing
Gradio UI factory.
"""

import os

# Environment defaults for the Space; explicit env vars always win over these.
_HF_DEFAULTS = {
    # Data files shipped with the Space.
    "DB_XLSX": "database.xlsx",
    "EMB_CACHE": "treatment_embeddings.pkl",
    # No Ollama server exists on HF -- use the in-process transformers backend.
    "LOCAL_LLM_PROVIDER": "transformers",
    # CPU-friendly open model that requires no auth token.
    "HF_LLM_MODEL": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
}
for _name, _value in _HF_DEFAULTS.items():
    os.environ.setdefault(_name, _value)

# Import AFTER the env defaults so the factory sees the configured backend.
from gradio_new_rag_app import make_app

demo = make_app()

if __name__ == "__main__":
    demo.launch()
database.xlsx ADDED
Binary file (41.4 kB). View file
 
llm_client.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Local LLM client abstraction (NO OpenAI/Claude).
4
+
5
+ Providers:
6
+ - ollama : calls a local Ollama server (your Windows dev)
7
+ - transformers : runs a local HF model in-process (best for Hugging Face Spaces CPU)
8
+
9
+ Env:
10
+ LOCAL_LLM_PROVIDER=ollama|transformers
11
+
12
+ Ollama:
13
+ OLLAMA_HOST=http://localhost:11434
14
+ OLLAMA_MODEL=llama3.2:1b
15
+
16
+ Transformers:
17
+ HF_LLM_MODEL=TinyLlama/TinyLlama-1.1B-Chat-v1.0 (recommended CPU default)
18
+ HF_MAX_NEW_TOKENS=450
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import json
24
+ import os
25
+ import re
26
+ from typing import Any, Dict, Optional
27
+
28
+ import requests
29
+
30
+
31
class LocalLLMClient:
    """Thin wrapper around a local LLM backend (no hosted APIs).

    Two providers are supported, selected via LOCAL_LLM_PROVIDER:
      * "ollama"       -- HTTP calls to a local Ollama server.
      * "transformers" -- in-process Hugging Face pipeline (CPU-friendly,
                          intended for HF Spaces).
    """

    def __init__(
        self,
        provider: Optional[str] = None,
        model: Optional[str] = None,
        host: Optional[str] = None,
        timeout_sec: int = 120,
    ):
        # Provider selection: explicit argument wins, then env, then "ollama".
        self.provider = (provider or os.getenv("LOCAL_LLM_PROVIDER", "ollama")).lower().strip()
        self.timeout_sec = int(os.getenv("LLM_TIMEOUT_SEC", str(timeout_sec)))

        # Ollama backend configuration.
        self.host = (host or os.getenv("OLLAMA_HOST", "http://localhost:11434")).strip()
        self.model = (model or os.getenv("OLLAMA_MODEL", "llama3.2:1b")).strip()

        # Transformers backend configuration (used on HF Spaces).
        self.hf_model_id = (os.getenv("HF_LLM_MODEL", "TinyLlama/TinyLlama-1.1B-Chat-v1.0")).strip()
        self.hf_max_new_tokens = int(os.getenv("HF_MAX_NEW_TOKENS", "450"))

        # The HF pipeline is created on first use, not here.
        self._hf_pipe = None

        if self.provider not in {"ollama", "transformers"}:
            raise ValueError(
                f"Unsupported LOCAL_LLM_PROVIDER='{self.provider}'. "
                "Supported: ollama, transformers."
            )

    # --------------------------- Public API ---------------------------

    def generate(self, prompt: str, temperature: float = 0.2, max_tokens: int = 900) -> str:
        """Run one completion and return the generated text ("" for blank prompts)."""
        cleaned = (prompt or "").strip()
        if not cleaned:
            return ""
        if self.provider == "ollama":
            return self._generate_ollama(cleaned, temperature=temperature, max_tokens=max_tokens)
        # Anything else validated in __init__ is "transformers".
        return self._generate_transformers(cleaned, temperature=temperature, max_tokens=max_tokens)

    # --------------------------- Ollama ---------------------------

    def _generate_ollama(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """POST to the Ollama /api/generate endpoint and return its response text.

        Raises RuntimeError with actionable hints when the server is unreachable
        or replies with a non-200 status.
        """
        endpoint = self.host.rstrip("/") + "/api/generate"
        body: Dict[str, Any] = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": float(temperature),
                "num_predict": int(max_tokens),
            },
        }

        try:
            resp = requests.post(endpoint, json=body, timeout=self.timeout_sec)
        except requests.RequestException as e:
            raise RuntimeError(
                "Failed to connect to local Ollama.\n"
                f"Tried: {endpoint}\n"
                "Fix:\n"
                " - Ensure Ollama is running\n"
                " - Confirm endpoint: iwr http://localhost:11434/api/tags -UseBasicParsing\n"
                f"Error: {repr(e)}"
            ) from e

        if resp.status_code != 200:
            raw = (resp.text or "").strip()
            detail = raw
            # Prefer the structured error field if Ollama returned JSON.
            try:
                parsed = resp.json()
                if isinstance(parsed, dict):
                    detail = parsed.get("error") or parsed.get("message") or raw
            except Exception:
                pass
            raise RuntimeError(
                "Ollama returned an error.\n"
                f"URL: {endpoint}\n"
                f"HTTP: {resp.status_code}\n"
                f"Model: {self.model}\n"
                f"Details: {detail}"
            )

        return (resp.json().get("response") or "").strip()

    # --------------------------- Transformers (HF Spaces) ---------------------------

    def _lazy_init_hf(self):
        """Create the Hugging Face text-generation pipeline on first use (CPU)."""
        if self._hf_pipe is not None:
            return

        # Imported here so Ollama-only installs don't need transformers.
        from transformers import pipeline

        # device=-1 forces CPU; bfloat16 support varies across Spaces, so keep it plain.
        self._hf_pipe = pipeline(
            "text-generation",
            model=self.hf_model_id,
            device=-1,
        )

    def _generate_transformers(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Generate with the in-process HF pipeline, stripping any echoed prompt."""
        self._lazy_init_hf()

        # Respect both the caller's budget and the HF cap (CPU is slow).
        budget = min(int(max_tokens), int(self.hf_max_new_tokens))

        # Simple instruction wrapper; helps most small instruct/chat models.
        wrapped = (
            "You are a helpful assistant.\n\n"
            f"{prompt}\n\n"
            "Answer:"
        )

        outputs = self._hf_pipe(
            wrapped,
            max_new_tokens=budget,
            do_sample=True,
            temperature=float(max(0.05, temperature)),
            top_p=0.9,
            repetition_penalty=1.1,
        )
        if not outputs:
            return ""

        # The pipeline yields list[{"generated_text": ...}]; take the first candidate.
        text = (outputs[0].get("generated_text", "") or "").strip()

        # Some models echo the prompt verbatim -- drop that prefix if present.
        if text.startswith(wrapped):
            text = text[len(wrapped):].strip()

        return text

    # --------------------------- JSON helpers ---------------------------

    @staticmethod
    def _strip_code_fences(text: str) -> str:
        """Remove a leading ``` / ```json fence and a trailing ``` fence."""
        stripped = text.strip()
        stripped = re.sub(r"^```(?:json)?\s*", "", stripped, flags=re.IGNORECASE)
        return re.sub(r"\s*```$", "", stripped).strip()

    def safe_json_loads(self, text: str) -> Dict[str, Any]:
        """Best-effort JSON object parse; returns {} instead of raising.

        Tries the whole (de-fenced) payload first, then falls back to the
        outermost {...} span found in the text. Non-dict JSON yields {}.
        """
        if not text:
            return {}

        candidate = self._strip_code_fences(text)

        try:
            whole = json.loads(candidate)
        except Exception:
            pass
        else:
            return whole if isinstance(whole, dict) else {}

        match = re.search(r"\{.*\}", candidate, flags=re.DOTALL)
        if match:
            try:
                inner = json.loads(match.group(0))
                return inner if isinstance(inner, dict) else {}
            except Exception:
                return {}

        return {}
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ pandas
3
+ numpy
4
+ openpyxl
5
+ scikit-learn
6
+ sentence-transformers
7
+ torch
8
+ transformers
9
+ accelerate
10
+ requests
11
+ beautifulsoup4
12
+ lxml
treatment_embeddings.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a91ab8b6879a80ecae1d39d1b36fda5b947db9a52b7ce0c651a55068d4f0cce
3
+ size 1745225
web_retriever.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ WebRetriever: lightweight, keyless web search + fetch for local CPU RAG / HF Spaces.
4
+
5
+ - Search: DuckDuckGo HTML endpoint (no API key)
6
+ - Fetch: requests + BeautifulSoup
7
+ - Extract: visible text + quick snippet, capped to keep prompts small
8
+
9
+ UPDATED FOR HF / PUBLIC TESTING:
10
+ - Graceful failure: never crash app when network blocks / 403 / 429 / timeouts occur
11
+ - Basic retries with backoff
12
+ - Canonicalize DuckDuckGo redirect URLs (uddg)
13
+ - Better HTML cleanup and snippet construction
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import random
19
+ import re
20
+ import time
21
+ from dataclasses import dataclass
22
+ from typing import List, Optional, Tuple
23
+ from urllib.parse import quote_plus, urlparse, parse_qs, unquote
24
+
25
+ import requests
26
+ from bs4 import BeautifulSoup
27
+
28
+
29
@dataclass
class WebDoc:
    """One web search hit: page title, resolved URL, and extracted text snippet."""

    title: str
    url: str
    snippet: str
34
+
35
+
36
class WebRetriever:
    """Keyless web search + page fetch (DuckDuckGo HTML endpoint + BeautifulSoup).

    Built to fail softly for HF Spaces / public demos: network errors, 403/429
    and timeouts produce empty results rather than exceptions, so the calling
    app never crashes. Requests are retried with exponential backoff and a
    polite delay is inserted between fetches.
    """

    def __init__(
        self,
        user_agent: Optional[str] = None,
        timeout_sec: int = 15,
        polite_delay_sec: float = 0.4,
        max_retries: int = 2,
        backoff_base_sec: float = 0.8,
    ):
        # A browser-like UA; some hosts (and HF outbound) reject obvious bot UAs.
        self.user_agent = user_agent or (
            "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
            "AppleWebKit/537.36 (KHTML, like Gecko) "
            "Chrome/120.0.0.0 Safari/537.36"
        )
        self.timeout_sec = timeout_sec
        self.polite_delay_sec = polite_delay_sec
        self.max_retries = max_retries
        self.backoff_base_sec = backoff_base_sec

    # ------------------------------------------------------------------
    # Internal: request with retries/backoff
    # ------------------------------------------------------------------
    def _request(self, method: str, url: str, **kwargs) -> Optional[requests.Response]:
        """Issue an HTTP request with retries; returns None on persistent failure.

        403/429 are treated as soft failures and retried with backoff, since
        search endpoints rate-limit aggressively.
        """
        headers = kwargs.pop("headers", {})
        headers.setdefault("User-Agent", self.user_agent)
        kwargs["headers"] = headers
        kwargs.setdefault("timeout", self.timeout_sec)

        for attempt in range(self.max_retries + 1):
            try:
                resp = requests.request(method, url, **kwargs)
                if resp.status_code in (403, 429):
                    # Soft failure: back off and retry; None after the last try.
                    self._sleep_backoff(attempt)
                    continue
                resp.raise_for_status()
            except Exception:
                # Connection errors and non-2xx statuses land here.
                if attempt >= self.max_retries:
                    return None
                self._sleep_backoff(attempt)
            else:
                return resp

        return None

    def _sleep_backoff(self, attempt: int) -> None:
        """Sleep for an exponentially growing, jittered delay (capped at 6s)."""
        delay = self.backoff_base_sec * (2 ** attempt) + random.uniform(0.0, 0.25)
        time.sleep(min(6.0, delay))

    # ------------------------------------------------------------------
    # URL cleaning: unwrap DuckDuckGo redirect links
    # ------------------------------------------------------------------
    @staticmethod
    def _unwrap_ddg_redirect(url: str) -> str:
        """Resolve DuckDuckGo /l/?uddg=<encoded> redirect links to the real URL."""
        try:
            parts = urlparse(url)
            if "duckduckgo.com" in parts.netloc.lower() and parts.path.startswith("/l/"):
                target = parse_qs(parts.query).get("uddg", [""])[0]
                if target:
                    return unquote(target)
        except Exception:
            pass
        return url

    @staticmethod
    def _dedupe_key(url: str) -> str:
        """Lower-cased netloc+path (no query/fragment) used for deduplication."""
        try:
            parts = urlparse(url)
            return f"{(parts.netloc or '').lower()}{(parts.path or '').lower()}"
        except Exception:
            return url

    # ------------------------------------------------------------------
    # Search using DuckDuckGo HTML
    # ------------------------------------------------------------------
    def search(self, query: str, max_results: int = 5) -> List[WebDoc]:
        """Query DuckDuckGo's HTML endpoint; returns up to max_results WebDocs.

        Snippets are left empty here -- use fetch_snippet() to fill them.
        Returns [] for blank queries or on network failure.
        """
        term = (query or "").strip()
        if not term:
            return []

        resp = self._request("GET", f"https://duckduckgo.com/html/?q={quote_plus(term)}")
        if resp is None:
            return []

        soup = BeautifulSoup(resp.text, "html.parser")
        hits: List[WebDoc] = []

        # Result anchors carry class "result__a"; scan extras in case of blanks.
        for anchor in soup.select("a.result__a")[: max_results * 3]:
            link = anchor.get("href") or ""
            if not link:
                continue

            hits.append(
                WebDoc(
                    title=anchor.get_text(" ", strip=True),
                    url=self._unwrap_ddg_redirect(link),
                    snippet="",
                )
            )
            if len(hits) >= max_results:
                break

        # Small pause between requests to stay under rate limits.
        time.sleep(self.polite_delay_sec)
        return hits

    # ------------------------------------------------------------------
    # Fetch and extract snippet
    # ------------------------------------------------------------------
    def fetch_snippet(self, url: str, max_chars: int = 900) -> str:
        """Download a page and return up to max_chars of visible text ("" on failure)."""
        target = (url or "").strip()
        if not target:
            return ""

        resp = self._request("GET", target)
        if resp is None:
            return ""

        soup = BeautifulSoup(resp.text, "html.parser")

        # Strip scripts/styles and common page chrome before extracting text.
        for tag in soup(["script", "style", "noscript", "header", "footer", "nav", "aside", "form", "svg"]):
            try:
                tag.decompose()
            except Exception:
                pass

        # Prefer the semantic content container when the page provides one.
        root = soup.find("article") or soup.find("main") or soup.body or soup
        text = re.sub(r"\s+", " ", root.get_text(" ", strip=True)).strip()
        if not text:
            return ""

        if len(text) > max_chars:
            # Cut on a word boundary and mark the truncation.
            text = text[:max_chars].rsplit(" ", 1)[0] + "…"

        time.sleep(self.polite_delay_sec)
        return text

    # ------------------------------------------------------------------
    # Combined: multiple queries -> docs
    # ------------------------------------------------------------------
    def search_and_fetch(
        self,
        queries: List[str],
        max_results_per_query: int = 3,
        max_docs: int = 6,
        max_chars_per_doc: int = 900,
    ) -> List[WebDoc]:
        """Run several queries, dedupe hits by host+path, and attach snippets.

        Returns as soon as max_docs documents have been collected.
        """
        collected: List[WebDoc] = []
        seen_keys = set()

        for query in queries:
            for hit in self.search(query, max_results=max_results_per_query):
                resolved = self._unwrap_ddg_redirect(hit.url)
                key = self._dedupe_key(resolved)
                if key in seen_keys:
                    continue
                seen_keys.add(key)

                collected.append(
                    WebDoc(
                        title=hit.title,
                        url=resolved,
                        snippet=self.fetch_snippet(resolved, max_chars=max_chars_per_doc),
                    )
                )
                if len(collected) >= max_docs:
                    return collected

        return collected