Spaces:
Sleeping
Sleeping
| import os, sys, runpy, json, requests, time | |
| from fastapi import HTTPException | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.embeddings import SentenceTransformerEmbeddings | |
| from langchain_community.vectorstores import Chroma | |
| from langchain.schema import Document | |
| from langchain.chains.summarize import load_summarize_chain | |
| from langchain.llms.base import LLM | |
| from pathlib import Path | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
| # from bookli_logger import BookliLogger | |
| EMBED_MODEL = "all-MiniLM-L6-v2" | |
| DB_ROOT = os.environ.get("DB_ROOT") | |
| # -------------------- Custom API LLM Wrapper -------------------- | |
class APILLM(LLM):
    """LangChain-compatible wrapper around an OpenAI-style chat-completions HTTP API."""

    api_url: str  # full chat/completions endpoint URL
    api_key: str  # bearer token placed in the Authorization header
    model: str    # provider-specific model identifier

    def _call(self, prompt, stop=None, **kwargs):
        """Send ``prompt`` as a single-turn chat request and return the reply text.

        Never raises on a bad response: parse failures are logged and a
        fallback error string is returned instead.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        data = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": "Follow user instructions carefully and answer concisely"},
                {"role": "user", "content": prompt},
            ],
            "temperature": 0.3,
            "max_tokens": 1024,
        }
        # timeout prevents a hung provider from blocking this call forever
        res = requests.post(self.api_url, json=data, headers=headers, timeout=120)
        try:
            response_json = res.json()
            print("✅ API response:", response_json)
            message = (
                response_json.get("choices", [{}])[0]
                .get("message", {})
                .get("content", "")
                .strip()
            )
            return message
        except Exception:
            print("❌ API response parse error:", res.text)
            return "Error parsing LLM response."

    @property
    def _llm_type(self) -> str:
        # LangChain's LLM base declares _llm_type as a property; overriding it
        # as a plain method shadows the property and breaks base-class usage.
        return "custom_api"
| # -------------------- LLM Selector -------------------- | |
def get_llm(provider: str):
    """Return an APILLM configured for the given provider name.

    Args:
        provider: One of the supported provider identifiers below.

    Raises:
        HTTPException: 400 when the provider is not supported.
    """
    # provider name -> (endpoint URL, env var holding the API key, model id)
    providers = {
        "qwen-2.5-coder-32b-instruct": (
            "https://openrouter.ai/api/v1/chat/completions",
            "OPENROUTER_KEY",
            "qwen/qwen-2.5-coder-32b-instruct:free",
        ),
        "llama-3.3-70b-instruct:free": (
            "https://openrouter.ai/api/v1/chat/completions",
            "OPENROUTER_KEY",
            "meta-llama/llama-3.3-70b-instruct:free",
        ),
        "qwen-3-235b-a22b-instruct-2507": (
            "https://api.cerebras.ai/v1/chat/completions",
            "CEREBRAS_KEY",
            "qwen-3-235b-a22b-instruct-2507",
        ),
        "gpt-oss-120b": (
            "https://api.cerebras.ai/v1/chat/completions",
            "CEREBRAS_KEY",
            "gpt-oss-120b",
        ),
    }
    try:
        api_url, key_env, model = providers[provider]
    except KeyError:
        raise HTTPException(status_code=400, detail="Unsupported provider") from None
    return APILLM(api_url=api_url, api_key=os.getenv(key_env), model=model)
| # -------------------- Vector Store Loader -------------------- | |
def get_vectorstore(job_id: str):
    """Open the persisted Chroma vector store for a job.

    Raises:
        HTTPException: 404 when no index directory exists for ``job_id``.
    """
    store_dir = os.path.join(DB_ROOT, job_id)
    if not os.path.exists(store_dir):
        raise HTTPException(status_code=404, detail=f"Job '{job_id}' not found or not indexed yet.")
    embedder = SentenceTransformerEmbeddings(model_name=EMBED_MODEL)
    return Chroma(persist_directory=store_dir, embedding_function=embedder)
def get_metadata(job_id: str):
    """Load metadata.json for a job and flatten it for LLM prompting.

    Reads {JOB_ROOT}/{job_id}/temp/metadata.json and returns a tuple of
    (metadata dict, flattened key:value text block).

    Raises:
        HTTPException: 404 when the metadata file is missing.
    """
    metadata_path = os.path.join(os.environ.get("JOB_ROOT"), job_id, "temp", "metadata.json")
    if not os.path.exists(metadata_path):
        raise HTTPException(
            status_code=404,
            detail=f"Metadata not found for job '{job_id}' at {metadata_path}"
        )
    with open(metadata_path, "r", encoding="utf-8") as fh:
        meta = json.load(fh)
    # Flattened form is what gets spliced into the prompt.
    return meta, format_metadata(meta)
def format_metadata(metadata: dict, prefix: str = "") -> str:
    """
    Recursively flatten nested dicts/lists into a key:value text block.

    Example:
        {
            "title": "Book",
            "chapters": [
                {"title": "Intro", "page": 1},
                {"title": "End", "page": 10}
            ]
        }
        →
        title: Book
        chapters.1.title: Intro
        chapters.1.page: 1
        chapters.2.title: End
        chapters.2.page: 10
    """
    parts = []
    for key, value in metadata.items():
        path = key if prefix == "" else f"{prefix}.{key}"
        if isinstance(value, dict):
            # Nested mapping: descend with the extended dotted path.
            parts.append(format_metadata(value, prefix=path))
        elif isinstance(value, list):
            primitives_only = all(isinstance(v, (str, int, float, bool)) for v in value)
            if primitives_only:
                # Flat list renders as a single comma-joined line.
                parts.append(f"{path}: {', '.join(str(v) for v in value)}")
            else:
                # Mixed/object list: 1-based index becomes a path segment.
                for index, element in enumerate(value, start=1):
                    child = f"{path}.{index}"
                    if isinstance(element, dict):
                        parts.append(format_metadata(element, prefix=child))
                    else:
                        parts.append(f"{child}: {element}")
        else:
            # Scalar leaf.
            parts.append(f"{path}: {value}")
    return "\n".join(parts)
def load_system_prompt(path: str = "./system_prompts/active_system_prompt") -> str:
    """
    Read the system prompt text from the given file.

    Args:
        path (str): Path to the system prompt file.

    Returns:
        str: The prompt text with surrounding whitespace stripped.

    Raises:
        FileNotFoundError: If the file does not exist.
        UnicodeDecodeError: If the file cannot be decoded as UTF-8.
    """
    prompt_file = Path(path)
    if not prompt_file.exists():
        raise FileNotFoundError(f"System prompt file not found: {prompt_file.resolve()}")
    # Force UTF-8 so the result is independent of the platform default.
    return prompt_file.read_text(encoding="utf-8").strip()
def summarize_for_reasoning(llm, retriever, q, metadata_text):
    """Condense retrieved docs into one summary Document for downstream chains.

    Picks 'stuff' summarization for small result sets (<= 5 docs) and
    'map_reduce' with throttling plus one retry for larger ones.
    """
    docs = retriever.get_relevant_documents(q)
    n_docs = len(docs)
    if n_docs <= 5:
        # Small context: a single LLM call is fast and quota-safe.
        chain = load_summarize_chain(llm, chain_type="stuff")
        summary_text = chain.run(docs)
        print(f"🟢 Using 'stuff' summarization (docs={n_docs})")
    else:
        # Larger context: multi-step summarize with request throttling.
        chain = load_summarize_chain(llm, chain_type="map_reduce")
        print(f"🟠 Using 'map_reduce' summarization (docs={n_docs})")
        try:
            # Throttle to avoid "too many requests" errors.
            time.sleep(1.0)
            summary_text = chain.run(docs)
        except Exception as err:
            # Best-effort: one retry after a longer backoff if throttled.
            print(f"⚠️ Summarization error ({err}), retrying...")
            time.sleep(3)
            summary_text = chain.run(docs)
    # Wrap for the combine_docs_chain, appending the flattened metadata.
    return [Document(page_content=f"Story Summary:\n{summary_text}\n\n{metadata_text}")]