gbakidz committed on
Commit
074f4df
·
verified ·
1 Parent(s): 43b236c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -8
app.py CHANGED
@@ -16,7 +16,7 @@ torch.set_num_threads(2)
16
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
17
  model = AutoModelForCausalLM.from_pretrained(
18
  MODEL_NAME,
19
- torch_dtype=torch.float32,
20
  low_cpu_mem_usage=True
21
  )
22
 
@@ -24,6 +24,7 @@ model.to("cpu")
24
 
25
  print("Model loaded!")
26
 
 
27
  # -------- REQUEST SCHEMA --------
28
  class RequestData(BaseModel):
29
  prompt: str
@@ -31,6 +32,12 @@ class RequestData(BaseModel):
31
  use_search: bool = True
32
 
33
 
 
 
 
 
 
 
34
  # -------- TOOL 1: SEARCH --------
35
  def search_links(query):
36
  url = f"https://duckduckgo.com/html/?q={query}"
@@ -80,18 +87,22 @@ def browse_web(query):
80
  return "\n\n".join(contents[:3])
81
 
82
 
83
- # -------- MEMORY BUILDER --------
84
  def build_prompt(prompt, history):
85
  convo = ""
86
 
87
- for user, bot in history:
88
- convo += f"User: {user}\nAssistant: {bot}\n"
 
 
 
 
89
 
90
  convo += f"User: {prompt}\nAssistant:"
91
  return convo
92
 
93
 
94
- # -------- GENERATION --------
95
  def generate_text(prompt):
96
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
97
 
@@ -103,16 +114,17 @@ def generate_text(prompt):
103
  do_sample=True
104
  )
105
 
106
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
 
107
 
108
 
109
  # -------- AGENT LOOP --------
110
  def agent(prompt, history, use_search=True):
111
 
112
- # Step 1: Build conversation
113
  base_prompt = build_prompt(prompt, history)
114
 
115
- # Step 2: Decide if search is needed
116
  decision_prompt = f"""
117
  You are an AI agent.
118
 
 
16
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
17
  model = AutoModelForCausalLM.from_pretrained(
18
  MODEL_NAME,
19
+ dtype=torch.float32,
20
  low_cpu_mem_usage=True
21
  )
22
 
 
24
 
25
  print("Model loaded!")
26
 
27
+
28
  # -------- REQUEST SCHEMA --------
29
  class RequestData(BaseModel):
30
  prompt: str
 
32
  use_search: bool = True
33
 
34
 
35
# -------- ROOT ROUTE --------
@app.get("/")
def home():
    """Health-check endpoint at the service root; confirms the API is up."""
    status = dict(message="API is running")
    return status
40
+
41
  # -------- TOOL 1: SEARCH --------
42
  def search_links(query):
43
  url = f"https://duckduckgo.com/html/?q={query}"
 
87
  return "\n\n".join(contents[:3])
88
 
89
 
90
# -------- MEMORY BUILDER --------
def build_prompt(prompt, history):
    """Build a plain-text conversation transcript for the model.

    Args:
        prompt: The current user message.
        history: Prior turns, either as chat-style dicts
            ({"role": "user"/"assistant", "content": ...}) or as legacy
            (user, bot) two-item pairs. Entries in neither format are
            skipped silently.

    Returns:
        str: The transcript, ending with "Assistant:" so the model
        continues the conversation as the assistant.
    """
    convo = ""

    for msg in history:
        if isinstance(msg, dict):
            # Chat-style message; default missing content to "" so we
            # never interpolate the literal string "None".
            role = msg.get("role")
            if role == "user":
                convo += f"User: {msg.get('content', '')}\n"
            elif role == "assistant":
                convo += f"Assistant: {msg.get('content', '')}\n"
        elif isinstance(msg, (tuple, list)) and len(msg) == 2:
            # Legacy (user, bot) pair format used by older callers.
            user, bot = msg
            convo += f"User: {user}\nAssistant: {bot}\n"

    convo += f"User: {prompt}\nAssistant:"
    return convo
103
 
104
 
105
+ # -------- GENERATION (FIXED OUTPUT) --------
106
  def generate_text(prompt):
107
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
108
 
 
114
  do_sample=True
115
  )
116
 
117
+ full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
118
+
119
+ # Remove prompt from output
120
+ return full_text[len(prompt):].strip()
121
 
122
 
123
  # -------- AGENT LOOP --------
124
  def agent(prompt, history, use_search=True):
125
 
 
126
  base_prompt = build_prompt(prompt, history)
127
 
 
128
  decision_prompt = f"""
129
  You are an AI agent.
130