sharktide committed on
Commit
ef7feb8
·
verified ·
1 Parent(s): 3fa53f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +241 -44
app.py CHANGED
@@ -7,7 +7,7 @@ import httpx
7
  from bs4 import BeautifulSoup
8
  from typing import List, Dict
9
  import asyncio
10
-
11
  app = FastAPI()
12
 
13
  app.add_middleware(
@@ -23,6 +23,170 @@ RATE_LIMIT = 25
23
  WINDOW_SECONDS = 60 * 60 * 24
24
  ip_store = {} # { ip: { "count": int, "reset": timestamp } }
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def check_rate_limit(ip: str):
28
  now = time.time()
@@ -99,40 +263,19 @@ def check_chat_rate_limit(ip: str):
99
  entry["count"] += 1
100
  return entry["count"]
101
 
102
def detect_tool_use(messages: list) -> bool:
    """Return True when any message in the conversation carries tool-call data.

    A message counts as tool-using when it contains either a "tool_calls"
    key (batched tool calls) or a legacy "function_call" key.
    """
    return any("tool_calls" in msg or "function_call" in msg for msg in messages)
115
-
116
-
117
def choose_model(messages: list, msg_count: int):
    """Select a (model, provider) pair from tool usage and conversation length.

    Tool-using conversations stay on groq with a larger model once the
    message count passes 20; plain long conversations go to cerebras;
    everything else gets the fast default on groq.
    """
    if detect_tool_use(messages):
        name = "openai/gpt-oss-120b" if msg_count > 20 else "openai/gpt-oss-20b"
        return name, "groq"

    if msg_count > 20:
        return "gpt-oss-120b", "cerebras"

    return "llama-3.1-8b-instant", "groq"
129
-
130
  @app.get("/genimg/{prompt}")
131
  async def generate_image(prompt: str, request: Request):
132
  client_ip = request.client.host
133
  check_rate_limit(client_ip)
134
 
135
- url = f"https://gen.pollinations.ai/image/{prompt}?model=zimage&key={PKEY}"
 
 
 
 
 
 
 
136
 
137
  async with httpx.AsyncClient() as client:
138
  response = await client.get(url)
@@ -189,30 +332,84 @@ async def generate_text(request: Request):
189
 
190
  ip = request.client.host
191
  msg_count = check_chat_rate_limit(ip)
192
-
 
193
  uses_tools = (
194
  "tools" in body and isinstance(body["tools"], list) and len(body["tools"]) > 0
195
  ) or ("tool_choice" in body and body["tool_choice"] not in [None, "none"])
196
-
197
- requested_model = body.get("model")
198
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  if uses_tools:
200
- if msg_count > 20:
 
201
  chosen_model = "openai/gpt-oss-120b"
202
  else:
203
  chosen_model = "openai/gpt-oss-20b"
204
  provider = "groq"
205
-
206
- else:
207
- if msg_count > 20:
208
- chosen_model = "gpt-oss-120b"
209
  provider = "cerebras"
210
- else:
211
- chosen_model = "llama-3.1-8b-instant"
212
- provider = "groq"
213
-
 
 
 
 
 
 
 
 
 
 
 
 
214
  body["model"] = chosen_model
215
- print(f"[TEXT GEN] INFO: Selected model: {chosen_model} {provider}")
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  stream = body.get("stream", False)
217
 
218
  if provider == "groq":
 
7
  from bs4 import BeautifulSoup
8
  from typing import List, Dict
9
  import asyncio
10
+ import re
11
  app = FastAPI()
12
 
13
  app.add_middleware(
 
23
  WINDOW_SECONDS = 60 * 60 * 24
24
  ip_store = {} # { ip: { "count": int, "reset": timestamp } }
25
 
26
REASONING_KEYWORDS = [
    # explicit reasoning requests
    "prove", "demonstrate", "derive", "justify", "verify",
    "show that", "walk through", "step by step", "reason through",
    "chain of reasoning", "rigorous", "formal proof",

    # analysis/comparison
    "analyze", "analysis of", "compare and contrast",
    "evaluate", "critically assess", "explain why",
    "explain how", "what causes", "implications of",

    # problem solving
    "solve", "solution to", "how would you approach",
    "strategy for", "optimize", "algorithm for",

    # technical domains
    "theorem", "lemma", "corollary",
    "complexity analysis", "big o", "time complexity",
    "mathematical", "statistical", "probabilistic",
    "model the", "simulate",
]

CODE_KEYWORDS = [
    "await", "async", "print(", "console.log(",
    "code", ".ts", ".js", ".py", ".repy", ".rb",
    "gnu", "gcc", "clang", "clang++", "program",
    "coding",
]

CREATIVE_KEYWORDS = [
    # cinematic cues
    "cinematic", "film still", "movie scene",
    "epic", "dramatic lighting", "moody lighting",
    "volumetric lighting", "depth of field",
    "anamorphic lens", "8k", "4k",

    # art styles
    "concept art", "digital painting",
    "fantasy art", "sci-fi", "mythical",
    "cyberpunk", "steampunk",
    "baroque", "surreal", "abstract",
    "oil painting", "watercolor",

    # rendering engines
    "octane render", "unreal engine",
    "ray tracing", "global illumination",

    # emotional narrative framing
    "emotional portrait", "story scene",
    "hero shot", "dramatic pose",
]

STRUCTURED_KEYWORDS = [
    "return as json",
    "output json",
    "json schema",
    "format as json",
    "structured output",
    "extract entities",
    "extract fields",
    "parse this",
    "convert to table",
    "create a table",
    "categorize into",
    "classify",
    "label the following",
    "taxonomy",
    "generate schema",
]

# Kept as pattern *strings* for backward compatibility (compiled once below).
# FIX: the original "\b∫\b"-style patterns could never match a free-standing
# math symbol — \b requires an adjacent word character and ∫/∑/∂ are
# non-word, so the symbol only matched when glued to a letter or digit.
# The symbols are now matched literally.
MATH_PATTERNS = [
    "∫", "∑", "∂",
    r"\bmatrix\b",
    r"\blimit\b",
    r"\bintegral\b",
    r"\bderivative\b",
    r"\bdifferential equation\b",
    r"\blinear algebra\b",
    r"\boptimi[sz]e\b",
    r"\bgradient\b",
    r"\bbackprop\b",
    r"\bproof\b",
    r"\btheorem\b",
]

LIGHTWEIGHT_KEYWORDS = [
    "hello", "hi", "hey",
    "thanks", "thank you",
    "define", "definition of",
    "what is", "who is",
    "quick question",
    "short answer",
    "brief explanation",
    "summarize",
    "paraphrase",
    "rewrite this",
]

# Compile once at import time instead of on every request.
_MATH_RES = [re.compile(p) for p in MATH_PATTERNS]
_REASONING_CUE_RE = re.compile(r"\b(if|therefore|assume|let x|given that)\b")


def is_long_context(messages: list) -> bool:
    """Return True when the combined message text exceeds 4000 characters.

    Message "content" may be None or a non-string payload (e.g. tool-call or
    multimodal messages); such entries contribute zero length instead of
    raising TypeError.
    """
    total_chars = 0
    for m in messages:
        content = m.get("content")
        if isinstance(content, str):
            total_chars += len(content)
    return total_chars > 4000


def contains_code(prompt: str) -> bool:
    """Heuristic: does the prompt look code-related?

    Fires on fenced code blocks or any CODE_KEYWORDS substring.
    NOTE(review): plain substring matching is deliberately loose — "code"
    also matches inside "decode"; confirm whether word boundaries are wanted.
    """
    if "```" in prompt:
        return True
    return any(kw in prompt for kw in CODE_KEYWORDS)


def is_math_heavy(prompt: str) -> bool:
    """True when the prompt contains math symbols or math vocabulary."""
    return any(rx.search(prompt) for rx in _MATH_RES)


def is_structured_task(prompt: str) -> bool:
    """True when the prompt asks for structured output (JSON, tables, labels)."""
    return any(kw in prompt for kw in STRUCTURED_KEYWORDS)


def multiple_questions(prompt: str) -> bool:
    """True when the prompt contains three or more question marks."""
    return prompt.count("?") >= 3


def extract_user_text(messages: list) -> str:
    """Concatenate all user-role message text, lowercased.

    Non-string content (None, multimodal lists) is skipped rather than
    raising, since the result is only used for keyword scanning.
    """
    parts = []
    for m in messages:
        if m.get("role") == "user":
            content = m.get("content")
            if isinstance(content, str):
                parts.append(content)
    return " ".join(parts).lower()


def is_complex_reasoning(prompt: str) -> bool:
    """True for long prompts, reasoning keywords, or logical cue words."""
    if len(prompt) > 800:
        return True
    if any(kw in prompt for kw in REASONING_KEYWORDS):
        return True
    return _REASONING_CUE_RE.search(prompt) is not None


def is_lightweight(prompt: str) -> bool:
    """True for short prompts (<100 chars) containing a small-talk keyword.

    NOTE(review): substring matching means "hi" also fires inside words like
    "this" — confirm whether word-boundary matching is intended.
    """
    if len(prompt) < 100:
        return any(kw in prompt for kw in LIGHTWEIGHT_KEYWORDS)
    return False


def is_cinematic_image_prompt(prompt: str) -> bool:
    """True when the prompt reads like a cinematic/artistic image request."""
    lowered = prompt.lower()
    return any(kw in lowered for kw in CREATIVE_KEYWORDS)
190
 
191
  def check_rate_limit(ip: str):
192
  now = time.time()
 
263
  entry["count"] += 1
264
  return entry["count"]
265
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  @app.get("/genimg/{prompt}")
267
  async def generate_image(prompt: str, request: Request):
268
  client_ip = request.client.host
269
  check_rate_limit(client_ip)
270
 
271
+ if is_cinematic_image_prompt(prompt):
272
+ chosen_model = "flux"
273
+ else:
274
+ chosen_model = "zimage"
275
+
276
+ print(f"[IMAGE GEN] Routing to model: {chosen_model}")
277
+
278
+ url = f"https://gen.pollinations.ai/image/{prompt}?model={chosen_model}&key={PKEY}"
279
 
280
  async with httpx.AsyncClient() as client:
281
  response = await client.get(url)
 
332
 
333
  ip = request.client.host
334
  msg_count = check_chat_rate_limit(ip)
335
+ prompt_text = extract_user_text(messages)
336
+
337
  uses_tools = (
338
  "tools" in body and isinstance(body["tools"], list) and len(body["tools"]) > 0
339
  ) or ("tool_choice" in body and body["tool_choice"] not in [None, "none"])
340
+
341
+ long_context = is_long_context(messages)
342
+ code_present = contains_code(prompt_text)
343
+ math_heavy = is_math_heavy(prompt_text)
344
+ structured_task = is_structured_task(prompt_text)
345
+ multi_q = multiple_questions(prompt_text)
346
+
347
+ score = 0
348
+
349
+ if long_context:
350
+ score += 3
351
+
352
+ if math_heavy:
353
+ score += 3
354
+
355
+ if structured_task:
356
+ score += 2
357
+
358
+ if code_present:
359
+ score += 2
360
+
361
+ if multi_q:
362
+ score += 1
363
+
364
+ for kw in REASONING_KEYWORDS:
365
+ if kw in prompt_text:
366
+ score += 1
367
+
368
+ chosen_model = "llama-3.1-8b-instant"
369
+ provider = "groq"
370
+
371
  if uses_tools:
372
+ # tools always need reliability
373
+ if score >= 4:
374
  chosen_model = "openai/gpt-oss-120b"
375
  else:
376
  chosen_model = "openai/gpt-oss-20b"
377
  provider = "groq"
378
+ elif code_heavy:
379
+ if score >= 6:
380
+ chosen_model = "zai-glm-4.7"
 
381
  provider = "cerebras"
382
+ elif score >= 6:
383
+ # extreme reasoning
384
+ chosen_model = "gpt-oss-120b"
385
+ provider = "cerebras"
386
+
387
+ elif score >= 4:
388
+ # medium-high reasoning
389
+ chosen_model = "llama-3.3-70b-versatile"
390
+ provider = "groq"
391
+
392
+ elif score >= 3 and structured_task:
393
+ chosen_model = "qwen-3-235b-a22b-instruct-2507"
394
+ provider = "cerebras"
395
+
396
+ # else → stay instant
397
+
398
  body["model"] = chosen_model
399
+
400
+ print(f"""
401
+ [ADVANCED ROUTER]
402
+ Score: {score}
403
+ Uses tools: {uses_tools}
404
+ Long context: {long_context}
405
+ Code present: {code_present}
406
+ Math heavy: {math_heavy}
407
+ Structured: {structured_task}
408
+ Multi-question: {multi_q}
409
+ → Selected: {chosen_model} ({provider})
410
+ """)
411
+
412
+
413
  stream = body.get("stream", False)
414
 
415
  if provider == "groq":