"""Question answering over a Hugging Face Inference endpoint via LangChain.

The model is prompted to reply with JSON matching ``OutputResponse``. Replies that
are wrapped in markdown fences or surrounded by extra prose are recovered with a
best-effort fallback parser.
"""

import json
import os
import re

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint

from app.schema import OutputResponse

# Environment variables checked for the Hugging Face API token, in priority order.
# (Hugging Face Spaces exposes HF_TOKEN automatically when the Space is linked to a token.)
_TOKEN_ENV_VARS = ("HUGGINGFACEHUB_API_TOKEN", "HF_TOKEN")

# Matches a JSON object inside a ```json ... ``` (or bare ``` ... ```) markdown fence.
_FENCED_JSON_RE = re.compile(r'```(?:json)?\s*(\{.*?\})\s*```', re.DOTALL)

_PROMPT_TEMPLATE = """You are a helpful assistant that provides concise and accurate answers.
Question: {question}

You MUST respond strictly in valid JSON format matching the schema below. Do not include any formatting or explanations outside the JSON object.
{format_instructions}
"""


def _get_hf_token():
    """Return the first non-empty Hugging Face token found in the environment, or None."""
    for env_name in _TOKEN_ENV_VARS:
        token = os.getenv(env_name)
        if token:
            return token
    return None


# Warn once at import time; load_model accepts either variable, so only warn when neither is set.
if not _get_hf_token():
    print("WARNING: Neither HUGGINGFACEHUB_API_TOKEN nor HF_TOKEN is set in the environment. "
          "Set one as a secret in your HuggingFace Space or .env file.")


def load_model(repo_id: str, max_length: int = 512, temperature: float = 0.5):
    """Create a text-generation ``HuggingFaceEndpoint`` for ``repo_id``.

    Args:
        repo_id: Hugging Face model repository id.
        max_length: passed to the endpoint as ``max_new_tokens``.
        temperature: sampling temperature. A value <= 0 sets ``do_sample=False``
            and passes ``temperature=None``.

    Returns:
        The configured ``HuggingFaceEndpoint``. A missing token is reported on
        stdout but does not prevent construction.
    """
    hf_token = _get_hf_token()
    if not hf_token:
        print("CRITICAL ERROR: No API Token found! Make sure you added your HUGGINGFACEHUB_API_TOKEN or HF_TOKEN in Space Secrets.")

    sampling = temperature > 0
    return HuggingFaceEndpoint(
        repo_id=repo_id,
        huggingfacehub_api_token=hf_token,
        task="text-generation",
        max_new_tokens=max_length,
        do_sample=sampling,
        temperature=temperature if sampling else None,
    )


def _output_text(raw_output):
    """Return the text of an LLM result: ``.content`` for chat messages, otherwise the value itself."""
    content = getattr(raw_output, "content", raw_output)
    return content if isinstance(content, str) else str(content)


def _extract_json_object(text):
    """Best-effort extraction of a JSON object (dict) from free-form model output.

    Tries a markdown-fenced object first, then the span from the first ``{`` to
    the last ``}``. Returns None when no candidate parses to a dict.
    """
    candidates = []
    fenced = _FENCED_JSON_RE.search(text)
    if fenced:
        candidates.append(fenced.group(1))
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        candidates.append(text[start:end + 1])

    for candidate in candidates:
        try:
            data = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(data, dict):
            return data
    return None


def generate_answer(question: str, llm) -> OutputResponse:
    """Ask ``llm`` the question and parse its reply into an ``OutputResponse``.

    Args:
        question: the user question inserted into the prompt.
        llm: a LangChain runnable, e.g. the endpoint returned by ``load_model``.
            Chat models, whose result has a ``.content`` attribute, are also handled.

    Returns:
        The parsed ``OutputResponse``. Failures are reported as responses rather
        than raised:
        - ``answer="Error generating answer"`` with the exception text, if
          building the prompt or invoking the model fails;
        - ``answer="Output format failed"`` with the raw model text, if the reply
          cannot be parsed into the schema.
    """
    try:
        parser = PydanticOutputParser(pydantic_object=OutputResponse)
        prompt = PromptTemplate(
            input_variables=["question"],
            partial_variables={"format_instructions": parser.get_format_instructions()},
            template=_PROMPT_TEMPLATE,
        )
        # HuggingFaceEndpoint returns a raw string; chat models return a message object.
        raw_output = (prompt | llm).invoke({"question": question})
    except Exception as e:  # top-level boundary: report model/network failures as a response
        return OutputResponse(answer="Error generating answer", justification=str(e))

    try:
        return parser.invoke(raw_output)
    except OutputParserException:
        pass  # fall through to the lenient recovery below

    # Fallback for models that wrap the JSON in markdown fences or add surrounding prose.
    text = _output_text(raw_output)
    data = _extract_json_object(text)
    if data is not None:
        try:
            return OutputResponse(**data)
        except (TypeError, ValueError):  # schema mismatch (pydantic ValidationError is a ValueError)
            pass
    return OutputResponse(answer="Output format failed", justification=text)