Rajan Sharma committed
Commit 44a97e9 · verified · Parent: 8aebe10

Update llm_router.py

Files changed (1)
  1. llm_router.py +32 -107
llm_router.py CHANGED
@@ -1,109 +1,34 @@
-from typing import Optional, List
-import time
-import cohere
-from settings import (
-    COHERE_API_KEY, COHERE_API_URL, COHERE_MODEL_PRIMARY, COHERE_EMBED_MODEL,
-    MODEL_SETTINGS, USE_OPEN_FALLBACKS
-)
-
-# Optional open-model fallback (only used if USE_OPEN_FALLBACKS=True)
-try:
-    from local_llm import LocalLLM
-    _HAS_LOCAL = True
-except Exception:
-    _HAS_LOCAL = False
-
-_client: Optional[cohere.Client] = None
-
-def _co_client() -> Optional[cohere.Client]:
-    global _client
-    if _client is not None:
-        return _client
-    if not COHERE_API_KEY:
-        return None
-    # NOTE: The Cohere Python SDK auto-selects API base; you can pass a custom base if provided.
-    if COHERE_API_URL:
-        _client = cohere.Client(api_key=COHERE_API_KEY, base_url=COHERE_API_URL, timeout=MODEL_SETTINGS.get("timeout_s", 45))
-    else:
-        _client = cohere.Client(api_key=COHERE_API_KEY, timeout=MODEL_SETTINGS.get("timeout_s", 45))
-    return _client
-
-def _retry(fn, attempts=3, backoff=0.8):
-    last = None
-    for i in range(attempts):
-        try:
-            return fn()
-        except Exception as e:
-            last = e
-            time.sleep(backoff * (2 ** i))
-    raise last if last else RuntimeError("Unknown error")
-
-def cohere_chat(prompt: str) -> Optional[str]:
-    cli = _co_client()
-    if not cli:
-        return None
-    def _call():
-        resp = cli.chat(
-            model=COHERE_MODEL_PRIMARY,
-            message=prompt,
-            temperature=MODEL_SETTINGS["temperature"],
-            max_tokens=MODEL_SETTINGS["max_new_tokens"],
+from typing import Optional
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from settings import OPEN_LLM_CANDIDATES, LOCAL_MAX_NEW_TOKENS
+
+class LocalLLM:
+    def __init__(self):
+        self.pipe = None
+        self._load_any()
+
+    def _load_any(self):
+        for mid in OPEN_LLM_CANDIDATES:
+            try:
+                tok = AutoTokenizer.from_pretrained(mid, trust_remote_code=True)
+                mdl = AutoModelForCausalLM.from_pretrained(
+                    mid, device_map="auto",
+                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                    trust_remote_code=True
+                )
+                self.pipe = pipeline("text-generation", model=mdl, tokenizer=tok)
+                return
+            except Exception:
+                continue
+
+    def chat(self, prompt: str) -> Optional[str]:
+        if not self.pipe: return None
+        out = self.pipe(
+            prompt, max_new_tokens=LOCAL_MAX_NEW_TOKENS,
+            do_sample=True, temperature=0.3, top_p=0.9, repetition_penalty=1.12,
+            eos_token_id=self.pipe.tokenizer.eos_token_id
         )
-        # SDK shape may provide .text, .reply, or generations
-        if hasattr(resp, "text") and resp.text:
-            return resp.text
-        if hasattr(resp, "reply") and resp.reply:
-            return resp.reply
-        if hasattr(resp, "generations") and resp.generations:
-            return resp.generations[0].text
-        return None
-    try:
-        return _retry(_call, attempts=3)
-    except Exception:
-        return None
-
-def open_fallback_chat(prompt: str) -> Optional[str]:
-    if not USE_OPEN_FALLBACKS or not _HAS_LOCAL:
-        return None
-    try:
-        return LocalLLM().chat(prompt)
-    except Exception:
-        return None
-
-def cohere_embed(texts: List[str]) -> List[List[float]]:
-    cli = _co_client()
-    if not cli or not texts:
-        return []
-    def _call():
-        resp = cli.embed(texts=texts, model=COHERE_EMBED_MODEL)
-        # Newer SDK: resp.embeddings; older: resp.data
-        return getattr(resp, "embeddings", None) or getattr(resp, "data", []) or []
-    try:
-        return _retry(_call, attempts=3)
-    except Exception:
-        return []
-
-def generate_narrative(scenario_text: str, structured_sections_md: str, rag_snippets: List[str]) -> str:
-    grounding = "\n\n".join([f"[RAG {i+1}]\n{t}" for i, t in enumerate(rag_snippets or [])])
-    prompt = f"""You are a Canadian healthcare operations copilot.
-Follow the scenario's requested deliverables exactly. Use the structured computations provided (already calculated deterministically) and the RAG snippets for grounding.
-
-# Scenario
-{scenario_text}
-
-# Deterministic Results (already computed)
-{structured_sections_md}
-
-# Grounding (Canadian sources, snippets)
-{grounding}
-
-Write a concise, decision-ready report tailored to provincial operations leaders.
-Do not invent numbers. If data are missing, say so clearly.
-"""
-    out = cohere_chat(prompt)
-    if out: return out
-    out = open_fallback_chat(prompt)
-    if out: return out
-    return "Unable to generate narrative at this time."
-
+        text = out[0]["generated_text"]
+        return text[len(prompt):].strip() if text.startswith(prompt) else text.strip()
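For reviewers tracing the call path: the deleted `open_fallback_chat` helper above built a fresh `LocalLLM()` on every call, reloading model weights each time. Below is a minimal sketch of how the relocated class might now be consumed with the same graceful-degradation contract. Only `LocalLLM` and its new home in `llm_router.py` come from this diff; the caching wrapper and the helper name are illustrative assumptions, not code from the repo.

```python
# Illustrative caller, modeled on the deleted open_fallback_chat() helper.
# Only LocalLLM (now defined in llm_router.py) comes from this commit;
# the module-level cache is an assumption.
from typing import Optional

from llm_router import LocalLLM  # the class lives here after this commit

_local: Optional[LocalLLM] = None  # cache so candidate models load only once

def open_fallback_chat(prompt: str) -> Optional[str]:
    """Return a local completion, or None so the caller can degrade gracefully."""
    global _local
    try:
        if _local is None:
            _local = LocalLLM()  # tries OPEN_LLM_CANDIDATES in order
        return _local.chat(prompt)
    except Exception:
        return None
```

One related note on `chat()`: a `transformers` text-generation pipeline returns the prompt plus the continuation by default, which is why the new code slices the prompt back off; passing `return_full_text=False` to the pipeline call is a built-in alternative to that manual stripping.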
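The new module's behavior is driven entirely by two `settings` values whose definitions are not part of this diff. The following sketch shows plausible shapes for them; the names match the imports above, but the concrete values are placeholders, not the repository's actual configuration.

```python
# settings.py (sketch). The names are the ones llm_router.py imports;
# the values below are illustrative assumptions only.
OPEN_LLM_CANDIDATES = [
    "Qwen/Qwen2.5-1.5B-Instruct",          # tried first by LocalLLM._load_any()
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # smaller fallback if the first fails
]
LOCAL_MAX_NEW_TOKENS = 512  # cap on new tokens per chat() call
```

Since `_load_any()` swallows every load failure, an empty or unreachable candidate list simply leaves `self.pipe` as `None`, and `chat()` then returns `None` to its caller.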