AlessandroMasala committed on
Commit
af34c50
·
verified ·
1 Parent(s): 3400f74

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +87 -8
agent.py CHANGED
@@ -3,27 +3,106 @@ import os
3
  import re
4
  import ast
5
  import requests
 
6
  import tempfile
 
7
  from pathlib import Path
8
  from enum import Enum
9
- from typing import Optional, Dict, List, Any, TypedDict
10
 
11
  from dotenv import load_dotenv
12
  from langgraph.graph import StateGraph, END
13
- from langchain.tools import Tool as LangTool
14
  from langchain_core.runnables import RunnableLambda
15
- from langchain.tools import StructuredTool
16
 
17
- # --- Integrazione Zephyr LLM ---
18
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
19
  from langchain_community.llms import HuggingFacePipeline
20
  import torch
21
 
22
- #
23
-
24
-
25
-
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
 
29
 
 
3
  import re
4
  import ast
5
  import requests
6
+ import time
7
  import tempfile
8
+
9
  from pathlib import Path
10
  from enum import Enum
11
+ from typing import TypedDict, Dict, List, Any
12
 
13
  from dotenv import load_dotenv
14
  from langgraph.graph import StateGraph, END
15
+ from langchain.tools import Tool as LangTool, StructuredTool
16
  from langchain_core.runnables import RunnableLambda
 
17
 
18
+ # ——— Zephyr LLM setup ———
19
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
20
  from langchain_community.llms import HuggingFacePipeline
21
  import torch
22
 
 
 
 
 
23
 
24
# Factory for the project's LLM backend.
def get_zephyr_llm():
    """Load HuggingFaceH4/zephyr-7b-alpha and wrap it for LangChain.

    Returns:
        HuggingFacePipeline: a text-generation pipeline running in
        fp16 with automatic device placement; generation is capped at
        256 new tokens with temperature 0.7 and top-p 0.9.
    """
    model_id = 'HuggingFaceH4/zephyr-7b-alpha'

    tok = AutoTokenizer.from_pretrained(model_id)
    # fp16 + device_map='auto' — requires accelerate; places layers on
    # whatever GPU/CPU mix is available.
    lm = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map='auto',
    )

    text_gen = pipeline(
        'text-generation',
        model=lm,
        tokenizer=tok,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.9,
    )
    return HuggingFacePipeline(pipeline=text_gen)
42
+
43
# --- Environment & API configuration --- #

# Load .env FIRST so any credentials it defines (e.g. an HF token) are
# in the environment before the model download below starts.  The
# previous version called load_dotenv() only after get_zephyr_llm().
load_dotenv()

# Endpoints of the GAIA scoring service used by the agent run.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
QUESTIONS_URL = f"{DEFAULT_API_URL}/questions"  # fetch the task list
SUBMIT_URL = f"{DEFAULT_API_URL}/submit"        # post final answers
FILE_PATH = f"{DEFAULT_API_URL}/files/"         # per-task attachment files

# Shared LLM instance.  NOTE(review): this loads a 7B model at import
# time, which is slow and memory-hungry — consider lazy initialization.
llm = get_zephyr_llm()
53
+
54
+ # --- TOOLS --- #
55
+
56
+
57
+
58
# Phases of the agent's state machine.
class AgentStep(Enum):
    """Named phases the agent moves through while answering a question."""

    ANALYZE = "analyze"
    SELECT_TOOLS = "select_tools"
    EXECUTE_TOOLS = "execute_tools"
    SYNTHESIZE = "synthesize_answer"
    ERROR_RECOVERY = "error_recovery"
    COMPLETE = "complete"
66
+
67
+
68
+
69
# Shared state threaded through the LangGraph nodes.
class AgentState(TypedDict):
    """State dictionary passed between graph nodes during a run."""

    question: str                  # working copy of the question text
    original_question: str         # untouched input, kept for reference
    selected_tools: List[str]      # tool names chosen for this question
    tool_results: Dict[str, Any]   # outputs keyed by tool name
    final_answer: str              # synthesized answer text
    current_step: str              # an AgentStep value naming the phase
    error_count: int               # errors seen so far
    max_errors: int                # threshold before giving up
79
+
80
+
81
# Build a fresh agent state for a new question.
def initialize_state(question: str) -> "AgentState":
    """Return the initial :class:`AgentState` for *question*.

    Fixes vs. the previous version:
    - the key is ``'tool_results'`` (matching the AgentState schema;
      it was misspelled ``'tool_result'``, so nodes reading
      ``state["tool_results"]`` would hit a KeyError),
    - the return annotation is ``AgentState`` rather than ``str``.
    """
    return {
        'question': question,
        'original_question': question,
        'selected_tools': [],
        'tool_results': {},   # was 'tool_result' — never matched the schema
        'final_answer': '',
        'current_step': AgentStep.ANALYZE.value,
        'error_count': 0,
        'max_errors': 3,      # abort threshold for error recovery
    }
93
+
94
+
95
def analyze_question(state: AgentState) -> AgentState:
    """Graph node: advance the state machine to tool selection.

    No real analysis happens yet — this node only records that the
    analyze phase is done and returns the (mutated) state.
    """
    next_phase = AgentStep.SELECT_TOOLS.value
    state["current_step"] = next_phase
    return state
98
+
99
+ # Agent Class
100
+ class GaiaAgent:
101
+ def __init__(self):
102
+ self.graph = graph
103
+
104
+ def __call__(self, task_id: str, question: str) -> str:
105
+
106
 
107
 
108