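"""Minimal web-search agent: retrieves context with Tavily, then answers with an LLM.

Environment variables (loaded from a local .env file via python-dotenv):
    TAVILY_API_KEY    - required; used for Tavily search / Q&A.
    CEREBRAS_API_KEY  - required; used by the Cerebras chat-completions client.
    LLM_MODEL         - optional; overrides the default model id (SAMPLE_MODEL).
    HF_TOKEN          - optional; only needed if switching to the Hugging Face router client.
"""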
import os
from dotenv import load_dotenv
from tavily import TavilyClient
from openai import OpenAI  # only used by the commented-out Hugging Face router client below
from cerebras.cloud.sdk import Cerebras
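# PyPI packages assumed for these imports: python-dotenv, tavily-python, openai, cerebras_cloud_sdk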

load_dotenv()

SAMPLE_MODEL = "gpt-oss-120b"  # default model id; override with the LLM_MODEL env var

class HanBasicAgent:
    def __init__(self):
        # HF_TOKEN is only required if you switch to the Hugging Face router client below.
        self.hf_token = os.getenv("HF_TOKEN")
        self.tavily_key = os.getenv("TAVILY_API_KEY")
        self.cerebras_key = os.getenv("CEREBRAS_API_KEY")
        assert self.tavily_key, "TAVILY_API_KEY is not set"
        assert self.cerebras_key, "CEREBRAS_API_KEY is not set"

        self.tavily = TavilyClient(api_key=self.tavily_key)
        # Alternative: route requests through the Hugging Face inference router instead of Cerebras.
        # self.llm_client = OpenAI(base_url="https://router.huggingface.co/v1", api_key=self.hf_token)
        self.llm_client = Cerebras(api_key=self.cerebras_key)
        self.model = os.getenv("LLM_MODEL", SAMPLE_MODEL)

    def _truncate_query(self, query: str, max_len: int = 390) -> str:
        """Clip long questions to keep them under Tavily's query-length limit."""
        return query[:max_len] + ("..." if len(query) > max_len else "")

    def answer(self, question: str, mode: str = "context", return_context: bool = False):
        # Truncate BEFORE the Tavily call so the query stays within its length limit.
        truncated_question = self._truncate_query(question)

        # "context" mode retrieves raw search context; "qna" mode uses Tavily's short answer.
        context = (
            self.tavily.get_search_context(query=truncated_question)
            if mode == "context"
            else self.tavily.qna_search(query=truncated_question)
        )

        if not context:
            context = "No context found. Answer based on your knowledge."

        # Send the truncated question together with the retrieved context to the LLM.
        messages = [
            {"role": "system", "content": "You are a helpful assistant. Do not call tools. Use context faithfully."},
            {"role": "user", "content": f"Question: {truncated_question}\nContext:\n{context}\nAnswer concisely."},
        ]

        comp = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
        )

        answer = comp.choices[0].message.content.strip()
        return (answer, context) if return_context else answer


    def __call__(self, question: str, **kwargs):
        # Forward keyword arguments (e.g. mode, return_context) to answer().
        return self.answer(question, **kwargs)


if __name__ == "__main__":
    agent = HanBasicAgent()

    # Default (context mode): search context is retrieved, then summarized by the LLM.
    print(agent("Who founded Tavily?"))

    # Quick Q&A mode: Tavily's qna_search answer is used as the context.
    print(agent("What is the capital of France?", mode="qna"))

    # Return both the answer and the retrieved context.
    answer, context = agent("Explain Burning Man floods", return_context=True)
    print(answer)