# NOTE(review): removed scraped Hugging Face "Spaces: Sleeping" page header —
# extraction artifact, not part of the app source.
# Standard library
import os
import re
import traceback

# Third-party
from huggingface_hub import login
from langchain.prompts import PromptTemplate
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.vectorstores import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, pipeline

# Authenticate against the Hugging Face Hub when a token is available so
# gated checkpoints (e.g. meta-llama) can be downloaded.
token = os.getenv("HF_TOKEN")
print("π HF_TOKEN available?", token is not None)
if token:
    login(token=token)
else:
    print("β No HF_TOKEN found in environment")
def build_qa():
    """Build and return the RAG question-answering chain (LCEL style).

    Loads sentence-transformer embeddings, opens the persisted Chroma
    vector store, wraps a small instruct LLM in a Hugging Face
    text-generation pipeline, and composes
    retriever -> prompt -> llm -> cleanup into one runnable chain.

    Returns:
        A LangChain runnable; call ``.invoke(question)`` with a plain
        string to get a cleaned plain-text answer.
    """
    print("π Starting QA pipeline...")

    # 1. Embeddings — must match the model used when the DB was built.
    print("πΉ Loading embeddings...")
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )

    # 2. Load the persisted vector DB.
    print("πΉ Loading Chroma DB...")
    vectorstore = Chroma(
        persist_directory="db",
        collection_name="rag-docs",
        embedding_function=embeddings,
    )
    # _collection is a private Chroma attribute; used only for a sanity
    # log of how many documents were indexed.
    print("π Docs in DB:", vectorstore._collection.count())

    # 3. Load the LLM (Llama-3.2-1B-Instruct).
    #    NOTE: an earlier revision used Phi-3 mini; the old header
    #    comment was stale and has been corrected.
    print("πΉ Loading LLM...")
    model_id = "meta-llama/Llama-3.2-1B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",       # place weights on whatever device is available
        trust_remote_code=True,  # allow custom model code from the repo, if any
    )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=128,
        temperature=0.4,                      # mostly deterministic, less rigid than 0
        do_sample=True,                       # allow some randomness
        top_p=0.9,                            # nucleus sampling to avoid loops
        repetition_penalty=1.2,               # penalize repeats
        eos_token_id=tokenizer.eos_token_id,  # stop at EOS
        return_full_text=False,               # return only the completion, not the prompt
    )
    llm = HuggingFacePipeline(pipeline=pipe)

    # 4. Retriever over the top-3 most similar chunks.
    retriever = vectorstore.as_retriever(search_kwargs={"k": 3})

    # 5. Prompt constraining the model to short, doc-grounded answers.
    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="""
Use the following context to answer the question.
- Answer from the docs
- Answer in plain natural language.
- Do not include code, imports, functions, or explanations of how to implement code.
- If you don't know, just say "I don't know."
Context:
{context}
Question: {question}
Answer (one short sentence):
""",
    )

    # 6. Helper functions
    def format_docs(docs):
        """Concatenate retrieved document chunks into one context string."""
        return "\n".join(doc.page_content for doc in docs)

    def hf_to_str(x):
        """Convert Hugging Face pipeline output to clean plain text.

        Handles both output shapes the pipeline can emit: a list of
        dicts carrying "generated_text", or anything else (stringified).
        """
        # BUGFIX: the old check (`isinstance(x, list) and "generated_text"
        # in x[0]`) raised IndexError on an empty list and performed a
        # *substring* test when x[0] was a plain string.
        if (
            isinstance(x, list)
            and x
            and isinstance(x[0], dict)
            and "generated_text" in x[0]
        ):
            text = x[0]["generated_text"]
        else:
            text = str(x)
        # Strip code-like patterns (imports, defs, classes, etc.) the
        # model sometimes emits despite the prompt.
        text = re.sub(r"(from\s+\w+\s+import\s+.*|import\s+\w+.*)", "", text)
        text = re.sub(r"def\s+\w+\(.*?\):.*", "", text, flags=re.DOTALL)
        text = re.sub(r"class\s+\w+.*?:.*", "", text, flags=re.DOTALL)
        text = re.sub(r"text\s*\+=.*", "", text)
        # Remove markdown/code fences & quotes.
        text = text.replace("```", "").replace("'''", "").replace('"""', "").replace("\\n", " ")
        # Normalize whitespace.
        text = re.sub(r"\s+", " ", text)
        # Deduplicate repeated sentences (sampling loops).
        sentences = []
        for s in re.split(r"(?<=[.!?])\s+", text):
            if s and s not in sentences:
                sentences.append(s)
        return " ".join(sentences).strip()

    # 7. Compose the RAG chain: retrieve -> prompt -> generate -> clean.
    rag_chain = (
        {
            "context": retriever | format_docs,
            "question": RunnablePassthrough(),
        }
        | prompt
        # BUGFIX: str(PromptValue) yields the pydantic repr ("text='...'"),
        # which leaked wrapper noise into the LLM input; to_string() is the
        # canonical plain-text conversion for a PromptValue.
        | (lambda x: x.to_string())
        | llm
        | (lambda x: hf_to_str(x))  # clean HF output
        | StrOutputParser()
    )
    print("β QA pipeline ready.")
    return rag_chain
# Build the chain once at import time so request handlers can reuse it.
try:
    qa_pipeline = build_qa()
    print("β qa_pipeline built successfully:", type(qa_pipeline))
except Exception as exc:  # boundary: keep the module importable even if setup fails
    qa_pipeline = None
    print("β Failed to build QA pipeline")
    print("Error message:", str(exc))
    traceback.print_exc()
def get_answer(query: str) -> str:
    """Run *query* through the module-level QA pipeline.

    Returns the answer text, or a human-readable error string when the
    pipeline failed to initialize or the invocation raised.
    """
    if qa_pipeline is None:
        return "β οΈ QA pipeline not initialized."
    try:
        # LCEL chain entry point: a plain string in, a plain string out.
        return qa_pipeline.invoke(query)
    except Exception as err:
        return f"β QA run failed: {err}"