File size: 2,652 Bytes
a122f91
 
 
9cace32
 
 
 
a122f91
 
 
 
 
 
 
e169104
 
 
 
 
 
a122f91
 
e169104
a122f91
 
 
 
1c56f8e
a122f91
 
a2c5f90
a122f91
 
a2c5f90
a122f91
 
a2c5f90
 
a122f91
a2c5f90
 
 
a122f91
 
a2c5f90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a122f91
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
import os
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.exceptions import OutputParserException
import re
import json
from app.schema import OutputResponse

# Import-time sanity check for Hugging Face credentials.
# load_model() accepts either HUGGINGFACEHUB_API_TOKEN or HF_TOKEN, so only
# warn when both are missing (or empty); otherwise a correctly configured
# environment that only defines HF_TOKEN would get a misleading warning.
if not (os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")):
    print("WARNING: neither HUGGINGFACEHUB_API_TOKEN nor HF_TOKEN is set in the environment. "
          "Set one of them as a secret in your HuggingFace Space or .env file.")

def load_model(repo_id: str, max_length: int = 512, temperature: float = 0.5):
    """Build a ``HuggingFaceEndpoint`` text-generation client for ``repo_id``.

    Args:
        repo_id: Model repository identifier passed through to the endpoint.
        max_length: Forwarded to the endpoint as ``max_new_tokens``.
        temperature: Sampling temperature. Any value that is not greater
            than zero turns sampling off (``do_sample=False``) and passes
            ``temperature=None`` to the endpoint.

    Returns:
        The configured ``HuggingFaceEndpoint`` instance.
    """
    # HUGGINGFACEHUB_API_TOKEN takes precedence; HF_TOKEN is the fallback.
    api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN") or os.getenv("HF_TOKEN")
    if not api_token:
        # Only reported, not raised: the endpoint is still constructed below.
        print("CRITICAL ERROR: No API Token found! Make sure you added your HUGGINGFACEHUB_API_TOKEN or HF_TOKEN in Space Secrets.")

    sampling_enabled = temperature > 0
    if sampling_enabled:
        effective_temperature = temperature
    else:
        effective_temperature = None

    return HuggingFaceEndpoint(
        repo_id=repo_id,
        huggingfacehub_api_token=api_token,
        task="text-generation",
        max_new_tokens=max_length,
        do_sample=sampling_enabled,
        temperature=effective_temperature,
    )


def generate_answer(question: str, llm) -> OutputResponse:
    """Ask ``llm`` a question and return the reply as an ``OutputResponse``.

    The prompt embeds the parser's format instructions so the model is asked
    for JSON matching the ``OutputResponse`` schema. Parsing happens in two
    stages: ``PydanticOutputParser`` first, and if that raises
    ``OutputParserException``, the first Markdown-fenced JSON object found in
    the reply text is tried.

    Args:
        question: The user question inserted into the prompt.
        llm: A runnable that can be piped after a ``PromptTemplate``. Its
            result may be a plain string or a message-like object exposing
            a ``content`` attribute.

    Returns:
        The parsed ``OutputResponse``. Failures are reported in the returned
        object rather than raised: an unparseable reply yields
        ``answer="Output format failed"`` with the reply text as
        justification, and any other error yields
        ``answer="Error generating answer"`` with the exception text.
    """
    try:
        parser = PydanticOutputParser(pydantic_object=OutputResponse)
        prompt = PromptTemplate(
            input_variables=["question"],
            partial_variables={"format_instructions": parser.get_format_instructions()},
            template="""You are a helpful assistant that provides concise and accurate answers.
Question: {question}

You MUST respond strictly in valid JSON format matching the schema below. Do not include any formatting or explanations outside the JSON object.
{format_instructions}
"""
        )

        # HuggingFaceEndpoint returns a raw string; a chat model wrapper is
        # expected to return a message object instead (handled below).
        chain = prompt | llm
        raw_output = chain.invoke({"question": question})

        try:
            return parser.invoke(raw_output)
        except OutputParserException:
            # The regex fallback needs text. Passing a message object to
            # re.search would raise TypeError and hide the real problem
            # behind the generic error response, so normalise first.
            if isinstance(raw_output, str):
                raw_text = raw_output
            else:
                raw_text = str(getattr(raw_output, "content", raw_output))

            # Fallback for models that insist on wrapping in markdown ```json ... ```
            match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', raw_text, re.DOTALL)
            if match:
                try:
                    return OutputResponse(**json.loads(match.group(1)))
                except (ValueError, TypeError):
                    # json.JSONDecodeError and pydantic validation errors are
                    # ValueError subclasses. A fenced block that is not valid
                    # for the schema is still a format failure, so fall
                    # through to the format-failure response.
                    pass

            return OutputResponse(answer="Output format failed", justification=raw_text)

    except Exception as e:
        # Top-level boundary: callers always get an OutputResponse back.
        return OutputResponse(answer="Error generating answer", justification=str(e))