|
|
import os |
|
|
from dotenv import load_dotenv |
|
|
from tavily import TavilyClient |
|
|
from openai import OpenAI |
|
|
from cerebras.cloud.sdk import Cerebras |
|
|
|
|
|
# Load environment variables (HF_TOKEN, TAVILY_API_KEY, CEREBRAS_API_KEY,
# LLM_MODEL) from a local .env file, if present.
load_dotenv()




# Default chat model used when the LLM_MODEL env var is not set.
SAMPLE_MODEL = "gpt-oss-120b"
|
|
|
|
|
class HanBasicAgent:
    """Minimal web-grounded QA agent.

    Retrieves context for a question via the Tavily search API, then asks a
    Cerebras-hosted chat model to answer concisely from that context.
    """

    def __init__(self):
        # Fail fast with an explicit error if credentials are missing.
        # (The original used `assert`, which is silently stripped under -O.)
        self.hf_token = os.getenv("HF_TOKEN")
        self.tavily_key = os.getenv("TAVILY_API_KEY")
        if not (self.hf_token and self.tavily_key):
            raise RuntimeError(
                "HF_TOKEN and TAVILY_API_KEY must be set in the environment"
            )

        self.tavily = TavilyClient(api_key=self.tavily_key)

        # LLM backend; model is overridable via the LLM_MODEL env var.
        self.llm_client = Cerebras(api_key=os.environ.get("CEREBRAS_API_KEY"))
        self.model = os.getenv("LLM_MODEL", SAMPLE_MODEL)

    def _truncate_query(self, query: str, max_len: int = 390) -> str:
        """Clamp *query* to *max_len* characters, appending "..." only when
        truncation actually occurred.

        390 presumably leaves headroom under Tavily's query-length limit
        — TODO confirm against the Tavily API docs.
        """
        return query[:max_len] + ("..." if len(query) > max_len else "")

    def answer(self, question: str, mode: str = "context",
               return_context: bool = False):
        """Answer *question* using Tavily-retrieved web context.

        Args:
            question: The user question (truncated to the Tavily limit).
            mode: "context" uses Tavily's search-context endpoint; any other
                value (e.g. "qna") uses the direct QnA search.
            return_context: When True, return an ``(answer, context)`` tuple
                instead of just the answer string. Defaults to False, so
                existing callers are unaffected.

        Returns:
            The model's answer string, or ``(answer, context)`` when
            *return_context* is True.
        """
        truncated_question = self._truncate_query(question)

        context = (
            self.tavily.get_search_context(query=truncated_question)
            if mode == "context"
            else self.tavily.qna_search(query=truncated_question)
        )

        # Degrade gracefully when search returns nothing usable.
        if not context:
            context = "No context found. Answer based on your knowledge."

        messages = [
            {"role": "system", "content": "You are a helpful assistant. Do not call tools. Use context faithfully."},
            {"role": "user", "content": f"Question: {truncated_question}\nContext:\n{context}\nAnswer concisely."}
        ]

        comp = self.llm_client.chat.completions.create(
            model=self.model,
            messages=messages,
        )

        result = comp.choices[0].message.content.strip()
        return (result, context) if return_context else result

    def __call__(self, question: str, **kwargs) -> str:
        # Forward keyword options (mode=..., return_context=...) to answer().
        # The original dropped them, making agent(q, mode="qna") a TypeError.
        return self.answer(question, **kwargs)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    agent = HanBasicAgent()

    # Default mode: fetch search context, then answer from it.
    print(agent("Who founded Tavily?"))

    # The original passed mode= through __call__, which accepts only the
    # question and raised TypeError; call answer() directly instead.
    print(agent.answer("What is the capital of France?", mode="qna"))

    # The original passed return_context=True, a keyword no method accepts;
    # dropped so the demo actually runs.
    print(agent.answer("Explain Burning Man floods"))
|
|
|