# boqapi / app/model.py
# Author: Dinuk-Di
# Commit 9cace32: "Model Updated"
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
import os
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.exceptions import OutputParserException
import re
import json
from app.schema import OutputResponse
# Warn at import time when no Hugging Face API token is configured.
# load_model() accepts either HUGGINGFACEHUB_API_TOKEN or HF_TOKEN, so only
# warn when neither is available (checking just one gave false warnings).
if not (os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")):
    print("WARNING: Neither HUGGINGFACEHUB_API_TOKEN nor HF_TOKEN is set in the environment. "
          "Set one as a secret in your HuggingFace Space or .env file.")
def load_model(repo_id: str, max_length: int = 512, temperature: float = 0.5):
    """Create a text-generation ``HuggingFaceEndpoint`` for *repo_id*.

    Args:
        repo_id: Hugging Face model repository the endpoint calls.
        max_length: Passed through as ``max_new_tokens``.
        temperature: When greater than 0, sampling is enabled with this
            temperature. Otherwise sampling is disabled and ``temperature``
            is passed as ``None``.

    Returns:
        The configured ``HuggingFaceEndpoint`` instance.

    A missing token only prints an error; the endpoint is still constructed.
    """
    # Prefer HUGGINGFACEHUB_API_TOKEN; fall back to HF_TOKEN (which Hugging Face
    # Spaces can expose automatically when the Space is linked to a token).
    token = os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")
    if not token:
        print("CRITICAL ERROR: No API Token found! Make sure you added your HUGGINGFACEHUB_API_TOKEN or HF_TOKEN in Space Secrets.")

    sampling = temperature > 0
    return HuggingFaceEndpoint(
        repo_id=repo_id,
        huggingfacehub_api_token=token,
        task="text-generation",
        max_new_tokens=max_length,
        do_sample=sampling,
        temperature=temperature if sampling else None,
    )
def _response_text(raw_output) -> str:
    """Return the text payload of an LLM invocation result.

    A plain string (e.g. from ``HuggingFaceEndpoint``) is returned unchanged.
    Objects exposing a ``.content`` attribute (LangChain chat-model messages)
    yield that content. Anything else is converted with ``str``.
    """
    content = getattr(raw_output, "content", raw_output)
    return content if isinstance(content, str) else str(content)


def _extract_json_object(text: str):
    """Best-effort recovery of a JSON object embedded in *text*.

    Candidates are tried in order:
      1. the body of a markdown ```json ... ``` (or bare ```) fence;
      2. the span from the first ``{`` to the last ``}``, for replies that
         wrap an unfenced object in prose.

    Returns the first candidate that decodes to a ``dict``, else ``None``.
    """
    candidates = []
    fenced = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', text, re.DOTALL)
    if fenced:
        candidates.append(fenced.group(1))
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        candidates.append(text[start:end + 1])

    for candidate in candidates:
        try:
            data = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(data, dict):
            return data
    return None


def generate_answer(question: str, llm) -> OutputResponse:
    """Ask *llm* a question and parse the reply into an ``OutputResponse``.

    Args:
        question: The user's question, substituted into the prompt.
        llm: A LangChain runnable (e.g. the endpoint returned by
            ``load_model``) that is piped after the prompt and invoked.

    Returns:
        The parsed ``OutputResponse``. This function does not raise:
          * if the reply cannot be parsed into the schema, ``answer`` is
            "Output format failed" and ``justification`` holds the raw reply;
          * if anything else fails (prompt building, the model call, ...),
            ``answer`` is "Error generating answer" and ``justification``
            holds the exception text.
    """
    try:
        parser = PydanticOutputParser(pydantic_object=OutputResponse)
        prompt = PromptTemplate(
            input_variables=["question"],
            partial_variables={"format_instructions": parser.get_format_instructions()},
            template="""You are a helpful assistant that provides concise and accurate answers.
Question: {question}
You MUST respond strictly in valid JSON format matching the schema below. Do not include any formatting or explanations outside the JSON object.
{format_instructions}
""",
        )
        chain = prompt | llm
        # Normalize to text: a completion endpoint returns a str, but a chat
        # model returns a message object, which the regex fallback can't search.
        text = _response_text(chain.invoke({"question": question}))

        try:
            return parser.invoke(text)
        except OutputParserException:
            # Fallback for models that wrap the JSON in markdown fences or prose.
            data = _extract_json_object(text)
            if data is not None:
                try:
                    return OutputResponse(**data)
                except (TypeError, ValueError):
                    # pydantic's ValidationError subclasses ValueError: the JSON
                    # decoded but doesn't match the schema, which is still a
                    # format failure, not a generation error.
                    pass
            return OutputResponse(answer="Output format failed", justification=text)
    except Exception as e:
        # Deliberate boundary: callers always receive an OutputResponse.
        return OutputResponse(answer="Error generating answer", justification=str(e))