Spaces:
Sleeping
Sleeping
| from langchain_core.prompts import PromptTemplate | |
| from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace | |
| import os | |
| from langchain_core.output_parsers import PydanticOutputParser | |
| from langchain_core.exceptions import OutputParserException | |
| import re | |
| import json | |
| from app.schema import OutputResponse | |
# Import-time sanity check for credentials. load_model() reads
# HUGGINGFACEHUB_API_TOKEN and falls back to HF_TOKEN, so only warn when
# *neither* is set; checking just the first variable produced a false warning
# on Spaces that expose only HF_TOKEN.
if not (os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")):
    print(
        "WARNING: Neither HUGGINGFACEHUB_API_TOKEN nor HF_TOKEN is set in the "
        "environment. Set one as a secret in your HuggingFace Space or .env file."
    )
def load_model(repo_id: str, max_length: int = 512, temperature: float = 0.5):
    """Create a LangChain ``HuggingFaceEndpoint`` configured for text generation.

    Args:
        repo_id: Hugging Face Hub model id, forwarded as ``repo_id``.
        max_length: Forwarded to the endpoint as ``max_new_tokens``.
        temperature: Sampling temperature. A value ``> 0`` turns sampling on
            and is forwarded as-is; any other value turns sampling off and
            forwards ``temperature=None``.

    Returns:
        The configured ``HuggingFaceEndpoint`` instance.

    The API token is read from ``HUGGINGFACEHUB_API_TOKEN``, falling back to
    ``HF_TOKEN``. If neither yields a non-empty value, an error is printed but
    the endpoint is still constructed (without a usable token).
    """
    api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")
    if not api_token:
        print("CRITICAL ERROR: No API Token found! Make sure you added your HUGGINGFACEHUB_API_TOKEN or HF_TOKEN in Space Secrets.")

    use_sampling = temperature > 0
    endpoint_kwargs = {
        "repo_id": repo_id,
        "huggingfacehub_api_token": api_token,
        "task": "text-generation",
        "max_new_tokens": max_length,
        "do_sample": use_sampling,
        # Only forward a temperature when sampling; greedy decoding gets None.
        "temperature": temperature if use_sampling else None,
    }
    return HuggingFaceEndpoint(**endpoint_kwargs)
def _extract_json_object(text: str) -> "dict | None":
    """Best-effort recovery of a JSON object embedded in free-form model text.

    Candidates are tried in order:
      1. a ``{...}`` object inside a Markdown code fence (```json ... ``` or
         a bare ``` ... ```), for models that insist on fencing their output;
      2. the span from the first ``{`` to the last ``}`` in the text, for
         models that surround unfenced JSON with prose.

    Returns the first candidate that decodes to a dict, otherwise None.
    """
    candidates = []
    fenced = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', text, re.DOTALL)
    if fenced:
        candidates.append(fenced.group(1))
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        candidates.append(text[start:end + 1])
    for candidate in candidates:
        try:
            data = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(data, dict):
            return data
    return None


def generate_answer(question: str, llm) -> OutputResponse:
    """Ask ``llm`` a question and parse its reply into an ``OutputResponse``.

    Args:
        question: The user question inserted into the prompt.
        llm: A LangChain runnable (such as a ``HuggingFaceEndpoint``) that is
            piped after the prompt. Its output may be a plain string or a
            message object exposing ``.content``.

    Returns:
        The parsed ``OutputResponse``. This function does not raise:
        - if the reply cannot be parsed/validated (even after recovering a
          JSON object from fenced or prose-wrapped text), it returns
          ``OutputResponse(answer="Output format failed", justification=<reply text>)``;
        - any other failure (prompt construction, model call, ...) yields
          ``OutputResponse(answer="Error generating answer", justification=str(error))``.
    """
    try:
        parser = PydanticOutputParser(pydantic_object=OutputResponse)
        prompt = PromptTemplate(
            input_variables=["question"],
            partial_variables={"format_instructions": parser.get_format_instructions()},
            template=(
                "You are a helpful assistant that provides concise and accurate answers.\n"
                "Question: {question}\n"
                "You MUST respond strictly in valid JSON format matching the schema below. "
                "Do not include any formatting or explanations outside the JSON object.\n"
                "{format_instructions}\n"
            ),
        )
        raw_output = (prompt | llm).invoke({"question": question})

        # HuggingFaceEndpoint returns a raw string; chat models return a message
        # object whose text lives in `.content`. Normalize to str so the regex
        # fallback below never sees a non-string.
        text = getattr(raw_output, "content", raw_output)
        if not isinstance(text, str):
            text = str(text)

        try:
            return parser.invoke(text)
        except OutputParserException:
            pass  # fall through to manual recovery

        data = _extract_json_object(text)
        if data is not None:
            try:
                return OutputResponse(**data)
            except (TypeError, ValueError):
                # Recovered JSON does not satisfy the schema (pydantic's
                # ValidationError subclasses ValueError); report it as a
                # format failure with the raw text rather than a generic error.
                pass
        return OutputResponse(answer="Output format failed", justification=text)
    except Exception as e:
        return OutputResponse(answer="Error generating answer", justification=str(e))