sharktide committed on
Commit
48286af
·
verified ·
1 Parent(s): a5bd7b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +134 -16
app.py CHANGED
@@ -18,11 +18,8 @@ app.add_middleware(
18
 
19
  OLLAMA_LIBRARY_URL = "https://ollama.com/library"
20
 
21
- # -----------------------------
22
- # RATE LIMITING (25 req/day/IP)
23
- # -----------------------------
24
  RATE_LIMIT = 25
25
- WINDOW_SECONDS = 60 * 60 * 24 # 24 hours
26
  ip_store = {} # { ip: { "count": int, "reset": timestamp } }
27
 
28
 
@@ -46,12 +43,65 @@ def check_rate_limit(ip: str):
46
 
47
  entry["count"] += 1
48
 
49
-
50
- # -----------------------------
51
- # IMAGE GENERATION ENDPOINT
52
- # -----------------------------
53
- PKEY = os.getenv("POLLINATIONS_KEY", "") # ensure this is set in your environment
54
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  @app.get("/genimg/{prompt}")
57
  async def generate_image(prompt: str, request: Request):
@@ -69,17 +119,11 @@ async def generate_image(prompt: str, request: Request):
69
  detail=f"Pollinations error: {response.status_code}"
70
  )
71
 
72
- # Pollinations always returns JPEG
73
  return Response(
74
  content=response.content,
75
  media_type="image/jpeg"
76
  )
77
 
78
-
79
-
80
- # -----------------------------
81
- # EXISTING MODELS SCRAPER
82
- # -----------------------------
83
  @app.get("/models")
84
  async def get_models() -> List[Dict]:
85
  async with httpx.AsyncClient() as client:
@@ -110,3 +154,77 @@ async def get_models() -> List[Dict]:
110
  })
111
 
112
  return models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  OLLAMA_LIBRARY_URL = "https://ollama.com/library"
20
 
 
 
 
21
  RATE_LIMIT = 25
22
+ WINDOW_SECONDS = 60 * 60 * 24
23
  ip_store = {} # { ip: { "count": int, "reset": timestamp } }
24
 
25
 
 
43
 
44
  entry["count"] += 1
45
 
46
# Pollinations API key, read from the environment; empty string when unset.
PKEY = os.getenv("POLLINATIONS_KEY", "")
48
# Per-IP running total of messages seen by the text-generation endpoint.
# NOTE(review): this grows without bound and never resets, unlike ip_store's
# windowed entries — confirm whether counts should expire too.
message_counts = {}

def increment_message_count(ip: str):
    """Bump and return the lifetime message count for *ip*."""
    count = message_counts.get(ip, 0) + 1
    message_counts[ip] = count
    return count
53
+
54
# Groq-hosted models that support tool / function calling.
GROQ_TOOL_MODELS = [
    "openai/gpt-oss-120b",
    "openai/gpt-oss-20b",
    "meta-llama/llama-4-scout-17b-16e-instruct",
    "qwen/qwen3-32b",
    "moonshotai/kimi-k2-instruct",
]

# Groq-hosted models for plain (non-tool) chat completions.
GROQ_NORMAL_MODELS = [
    "llama-3.1-8b-instant",
    "llama-3.3-70b-versatile",
    "meta-llama/llama-4-maverick-17b-128e-instruct",
    "meta-llama/llama-guard-4-12b",
    "openai/gpt-oss-safeguard-20b",
    "qwen/qwen3-32b",
]

# Models served by Cerebras.
# NOTE(review): none of these three lists is referenced by the code visible
# here — presumably consumed elsewhere (e.g. a model-listing endpoint);
# confirm before removing.
CEREBRAS_MODELS = [
    "gpt-oss-120b",
    "llama3.1-8b",
    "qwen-3-235b-a22b-instruct-2507",
    "zai-glm-4.7",
]
77
+
78
def detect_tool_use(messages: list) -> bool:
    """
    Return True when any message in the conversation looks tool-related.

    A message qualifies if it carries a "tool_calls" key or a legacy
    "function_call"-style key (key presence only; values are not inspected).
    """
    return any(
        ("tool_calls" in entry) or ("function_call" in entry)
        for entry in messages
    )


def choose_model(messages: list, msg_count: int):
    """
    Pick a (model, provider) pair for a conversation.

    Tool-using conversations always route to Groq; once a caller has sent
    more than 20 messages, the larger models are selected.
    """
    if detect_tool_use(messages):
        # Groq hosts the tool-capable gpt-oss models.
        model = "openai/gpt-oss-120b" if msg_count > 20 else "openai/gpt-oss-20b"
        return model, "groq"

    # Plain chat: long sessions move to Cerebras' larger model.
    if msg_count > 20:
        return "gpt-oss-120b", "cerebras"

    return "llama-3.1-8b-instant", "groq"
105
 
106
  @app.get("/genimg/{prompt}")
107
  async def generate_image(prompt: str, request: Request):
 
119
  detail=f"Pollinations error: {response.status_code}"
120
  )
121
 
 
122
  return Response(
123
  content=response.content,
124
  media_type="image/jpeg"
125
  )
126
 
 
 
 
 
 
127
  @app.get("/models")
128
  async def get_models() -> List[Dict]:
129
  async with httpx.AsyncClient() as client:
 
154
  })
155
 
156
  return models
157
+
158
async def _forward_openai_chat(url: str, api_key: str, body: dict):
    """POST *body* to an OpenAI-compatible chat endpoint and relay the reply.

    Raises HTTPException(502) when the upstream answers with a non-JSON
    payload (e.g. an HTML error page) — previously the bare ``r.json()``
    call let the decode error escape as an unhandled 500.
    """
    headers = {"Authorization": f"Bearer {api_key}"}
    async with httpx.AsyncClient(timeout=None) as client:
        r = await client.post(url, json=body, headers=headers)
    try:
        payload = r.json()
    except ValueError:
        raise HTTPException(502, "Upstream returned a non-JSON response")
    return JSONResponse(status_code=r.status_code, content=payload)


@app.post("/gen")
async def generate_text(request: Request):
    """Route a chat-completion request to Groq or Cerebras and proxy it.

    Routing rules:
      * requests that declare tools (a non-empty ``tools`` list, or a
        ``tool_choice`` other than None/"none") always go to Groq's
        tool-capable gpt-oss models;
      * otherwise, callers past 20 messages (tracked per IP) are upgraded
        to the larger Cerebras model; short sessions use Groq's
        llama-3.1-8b-instant.

    Raises 400 when ``messages`` is missing/empty and 500 when the chosen
    provider's API key is not configured.

    NOTE(review): unlike the image endpoint, check_rate_limit() is not
    applied here — confirm whether that is intentional.
    """
    body = await request.json()

    messages = body.get("messages", [])
    if not isinstance(messages, list) or len(messages) == 0:
        raise HTTPException(400, "messages[] is required")

    ip = request.client.host
    msg_count = increment_message_count(ip)

    uses_tools = (
        "tools" in body and isinstance(body["tools"], list) and len(body["tools"]) > 0
    ) or ("tool_choice" in body and body["tool_choice"] not in [None, "none"])

    if uses_tools:
        provider = "groq"
        chosen_model = "openai/gpt-oss-120b" if msg_count > 20 else "openai/gpt-oss-20b"
    elif msg_count > 20:
        provider = "cerebras"
        chosen_model = "gpt-oss-120b"
    else:
        provider = "groq"
        chosen_model = "llama-3.1-8b-instant"

    # Any caller-supplied "model" is deliberately overridden — this endpoint
    # owns model selection (the previous `requested_model` local was unused).
    body["model"] = chosen_model

    if provider == "groq":
        api_key = os.getenv("GROQ_KEY", "")
        if not api_key:
            raise HTTPException(500, "Missing GROQ_KEY")
        return await _forward_openai_chat(
            "https://api.groq.com/openai/v1/chat/completions", api_key, body
        )

    # provider == "cerebras" — the only remaining branch, so the old
    # unreachable "Unknown provider routing error" tail is gone.
    api_key = os.getenv("CER_KEY", "")
    if not api_key:
        raise HTTPException(500, "Missing CER_KEY")
    return await _forward_openai_chat(
        "https://api.cerebras.ai/v1/chat/completions", api_key, body
    )