# Query-processing pipeline for a hydrogen-certification RAG system:
# certification classification, specificity classification, query refinement
# and step-back query generation, built on LangChain prompts + Groq chat models.
| import os | |
| from dotenv import load_dotenv | |
| from langchain.prompts import PromptTemplate | |
| from langchain_groq import ChatGroq | |
| from typing import Literal | |
# Load environment variables (e.g. GROQ_API_KEY) from a local .env file
# so initialize_llms() can read the key via os.getenv.
load_dotenv()
# LLM construction
def initialize_llms():
    """Build the two Groq chat models used by the query pipeline.

    Returns:
        dict: "rewrite_llm" — a low-temperature llama-3.3-70b-versatile model
        used for classification and query refinement; "step_back_llm" — a
        temperature-0 Gemma2-9B-IT model used for step-back query generation.
    """
    # Both clients share the same API key from the environment.
    key = os.getenv("GROQ_API_KEY")
    rewrite_llm = ChatGroq(
        temperature=0.1,
        model="llama-3.3-70b-versatile",
        api_key=key,
    )
    step_back_llm = ChatGroq(
        temperature=0,
        model="Gemma2-9B-IT",
        api_key=key,
    )
    return {"rewrite_llm": rewrite_llm, "step_back_llm": step_back_llm}
# Certification classification
def classify_certification(
    query: str,
    llm: ChatGroq,
    certs_dir: str = "docs/processed"
) -> str:
    """
    Classify which certification a query is referring to.

    Args:
        query: The user's question.
        llm: Chat model used to perform the classification.
        certs_dir: Directory whose sub-folders name the available
            certifications. When it exists and is non-empty, the list shown
            to the model is derived from it; otherwise a built-in list is used.

    Returns:
        Certification name, or 'no certification mentioned'.
    """
    # Fallback list — kept verbatim so behaviour is unchanged when
    # certs_dir is missing or empty.
    default_certs = "2BSvs, CertifHy - National Green Certificate (NGC), CertifHy - RFNBO, Certified_Hydrogen_Producer, GH2_Standard, Green_Hydrogen_Certification, ISCC CORSIA, ISCC EU (International Sustainability & Carbon Certification), ISCC PLUS, ISO_19880_Hydrogen_Quality, REDcert-EU, RSB, Scottish Quality Farm Assured Combinable Crops (SQC), TUV Rheinland H2.21, UK RTFO_regulation"
    # Bug fix: certs_dir was accepted but never read, so the prompt's
    # "exact name ... as it appears in the directory" instruction could not
    # be honoured. Prefer the live directory listing when available.
    try:
        entries = sorted(
            name for name in os.listdir(certs_dir)
            if os.path.isdir(os.path.join(certs_dir, name))
        )
    except OSError:
        entries = []
    available_certs = ", ".join(entries) if entries else default_certs
    template = """
You are an AI assistant classifying user queries based on the certification they are asking for in a RAG system.
Classify the given query into one of the following certifications:
- {available_certifications}
Don't need any explanation, just return the name of the certification.
Use the exact name of the certification as it appears in the directory.
If the query refers to multiple certifications, return the most relevant one.
If the query doesn't mention any certification, respond with "no certification mentioned".
Original query: {original_query}
Classification:
"""
    prompt = PromptTemplate(
        input_variables=["original_query", "available_certifications"],
        template=template
    )
    chain = prompt | llm
    response = chain.invoke({
        "original_query": query,
        "available_certifications": available_certs
    }).content.strip()
    return response
| # Query specificity classification | |
| def classify_query_specificity( | |
| query: str, | |
| llm: ChatGroq | |
| ) -> Literal["specific", "general", "too narrow"]: | |
| """ | |
| Classify query specificity. | |
| Returns one of: 'specific', 'general', or 'too narrow'. | |
| """ | |
| template = """ | |
| You are an AI assistant classifying user queries based on their specificity for a RAG system. | |
| Classify the given query into one of: | |
| - "specific" → If it asks for exact values, certifications, or well-defined facts. | |
| - "general" → If it is broad and needs refinement for better retrieval. | |
| - "too narrow" → If it is very specific and might need broader context. | |
| DO NOT output explanations, only return one of: "specific", "general", or "too narrow". | |
| Original query: {original_query} | |
| Classification: | |
| """ | |
| prompt = PromptTemplate( | |
| input_variables=["original_query"], | |
| template=template | |
| ) | |
| chain = prompt | llm | |
| response = chain.invoke({"original_query": query}).content.strip().lower() | |
| return response.split("\n")[0].strip() # type: ignore | |
# Query refinement
def refine_query(
    query: str,
    llm: ChatGroq
) -> str:
    """Rewrite *query* to be clearer and more detailed without changing its intent.

    Args:
        query: The original user question.
        llm: Chat model that performs the rewrite.

    Returns:
        The refined query text produced by the model.
    """
    template = """
You are an AI assistant that improves queries for retrieving precise certification and compliance data.
Rewrite the query to be clearer while keeping the intent unchanged.
Original query: {original_query}
Refined query:
"""
    rewrite_prompt = PromptTemplate(
        input_variables=["original_query"],
        template=template,
    )
    # Pipe the prompt straight into the model and return its text content.
    result = (rewrite_prompt | llm).invoke({"original_query": query})
    return result.content
# Step-back query generation
def generate_step_back_query(
    query: str,
    llm: ChatGroq
) -> str:
    """Generate a broader "step-back" version of *query*.

    Used for overly narrow questions: the broader query retrieves background
    context that the original phrasing would miss.

    Args:
        query: The original (too narrow) user question.
        llm: Chat model that performs the generalization.

    Returns:
        The step-back query text produced by the model.
    """
    template = """
You are an AI assistant generating broader queries to improve retrieval context.
Given the original query, generate a more general step-back query to retrieve relevant background information.
Original query: {original_query}
Step-back query:
"""
    step_back_prompt = PromptTemplate(
        input_variables=["original_query"],
        template=template,
    )
    reply = (step_back_prompt | llm).invoke({"original_query": query})
    return reply.content
# Main query processing pipeline
def process_query(
    original_query: str,
    llms: dict
) -> str:
    """
    Process a query through the full pipeline:
    1. Classify specificity
    2. Apply appropriate refinement
    """
    label = classify_query_specificity(original_query, llms["rewrite_llm"])
    # Both "specific" and "general" queries receive the same clarity rewrite;
    # only "too narrow" queries are broadened via a step-back query.
    if label in ("specific", "general"):
        return refine_query(original_query, llms["rewrite_llm"])
    if label == "too narrow":
        return generate_step_back_query(original_query, llms["step_back_llm"])
    # Unrecognized labels leave the query untouched.
    return original_query
# Test setup
def test_hydrogen_certification_functions():
    """Smoke-test classification, specificity, and the full pipeline end to end."""
    llms = initialize_llms()

    # Ensure the certification directory exists with a few sample cert folders
    # (os.makedirs creates the parent directory as needed).
    test_certs_dir = "docs/processed"
    sample_certs = [
        "GH2_Standard",
        "Certified_Hydrogen_Producer",
        "Green_Hydrogen_Certification",
        "ISO_19880_Hydrogen_Quality",
    ]
    for cert_name in sample_certs:
        os.makedirs(os.path.join(test_certs_dir, cert_name), exist_ok=True)

    # (query, expected specificity) pairs covering all three labels.
    test_queries = [
        ("What are the purity requirements in GH2 Standard?", "specific"),
        ("How does hydrogen certification work?", "general"),
        ("What's the exact ppm of CO2 allowed in ISO_19880_Hydrogen_Quality section 4.2?", "too narrow"),
        ("What safety protocols exist for hydrogen storage?", "general"),
    ]

    print("=== Testing Certification Classification ===")
    for query, _ in test_queries:
        cert = classify_certification(query, llms["rewrite_llm"], test_certs_dir)
        print(f"Query: {query}\nClassification: {cert}\n")

    print("\n=== Testing Specificity Classification ===")
    for query, expected_type in test_queries:
        specificity = classify_query_specificity(query, llms["rewrite_llm"])
        print(f"Query: {query}\nExpected: {expected_type}, Got: {specificity}\n")

    print("\n=== Testing Full Query Processing ===")
    for query, _ in test_queries:
        processed = process_query(query, llms)
        print(f"Original: {query}\nProcessed: {processed}\n")
# Run the smoke tests only when executed as a script, so the module can be
# imported elsewhere without triggering LLM calls or directory creation.
if __name__ == "__main__":
    test_hydrogen_certification_functions()