sharktide committed on
Commit
9162e29
·
verified ·
1 Parent(s): a90dfbb

Add prompt analyze endpoint

Browse files
Files changed (1) hide show
  1. gen.py +87 -0
gen.py CHANGED
@@ -835,3 +835,90 @@ async def generate_text(
835
  return JSONResponse(status_code=r.status_code, content=payload)
836
 
837
  raise HTTPException(500, "Unknown provider routing error")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
835
  return JSONResponse(status_code=r.status_code, content=payload)
836
 
837
  raise HTTPException(500, "Unknown provider routing error")
838
+
839
@router.post("/prompt_analyze")
async def analyze_prompt(
    request: Request
):
    """Choose a model and provider for a prompt without generating anything.

    The JSON request body is read for:
      * ``prompt`` - a non-empty list of messages (required).
      * ``tools`` / ``tool_choice`` - optional; used only to detect tool usage.

    A complexity score is accumulated from the classifier helpers defined
    elsewhere in this module (long context, math, structured task, code,
    multiple questions, reasoning keywords) and, together with the image /
    tool / code flags, selects the model.

    Returns:
        dict with keys ``chosen_model`` and ``provider``.

    Raises:
        HTTPException(400): the body is not valid JSON, is not a JSON object,
            or ``prompt`` is missing, empty, or not a list.
    """
    # Malformed JSON raises a ValueError subclass; report it as a client
    # error instead of letting it surface as a 500.
    try:
        body = await request.json()
    except ValueError as exc:
        raise HTTPException(400, "Request body must be valid JSON") from exc

    # Valid JSON may still be a list/string/number, which has no ``.get``.
    if not isinstance(body, dict):
        raise HTTPException(400, "Request body must be a JSON object")

    messages = body.get("prompt", [])
    if not isinstance(messages, list) or len(messages) == 0:
        # The field read above is "prompt", so the error names that field.
        raise HTTPException(400, "prompt[] is required")

    total_chars, total_bytes = calculate_messages_size(messages)
    prompt_text = extract_user_text(messages)

    # Tools count as "used" when a non-empty tool list is supplied or when
    # tool_choice is set to anything other than None / "none".
    tools = body.get("tools")
    uses_tools = (
        isinstance(tools, list) and len(tools) > 0
    ) or body.get("tool_choice") not in (None, "none")

    long_context = is_long_context(messages)
    code_present = contains_code(prompt_text)
    math_heavy = is_math_heavy(prompt_text)
    structured_task = is_structured_task(prompt_text)
    multi_q = multiple_questions(prompt_text)
    code_heavy = is_code_heavy(prompt_text, code_present, long_context)

    # Weighted complexity score; each signal contributes a fixed amount.
    score = 0
    if long_context:
        score += 3
    if math_heavy:
        score += 3
    if structured_task:
        score += 2
    if code_present:
        score += 2
    if multi_q:
        score += 1
    # One extra point per reasoning keyword found in the prompt text.
    score += sum(1 for kw in REASONING_KEYWORDS if kw in prompt_text)
    # Cap at 10. The score is only consulted on the non-image path, so
    # clamping here is equivalent to clamping inside that branch.
    score = min(score, 10)

    # Default choice when no rule below overrides it.
    chosen_model = "llama-3.1-8b-instant"
    provider = "groq"
    has_images = contains_images(messages)

    if has_images:
        # Image input takes priority over every score-based rule.
        chosen_model = "gpt-4o-mini"
        provider = "navy vision"
    elif uses_tools:
        if score >= 6:
            chosen_model = "nemotron-3-super"
            provider = "navy"
        elif score >= 4:
            chosen_model = "openai/gpt-oss-120b"
            provider = "groq"
        else:
            chosen_model = "openai/gpt-oss-20b"
            provider = "groq"
    elif code_present:
        if code_heavy and score >= 6:
            chosen_model = "qwen-3-235b-a22b-instruct-2507"
            provider = "cerebras"
        elif score >= 4:
            chosen_model = "llama-3.3-70b-versatile"
            provider = "groq"
        # Otherwise the default model/provider stands.
    elif score >= 4:
        chosen_model = "meta-llama/llama-4-scout-17b-16e-instruct"
        provider = "groq"

    # Prompts exceeding either Groq size limit are rerouted to Cerebras,
    # whatever Groq model was picked above.
    if provider == "groq" and (
        total_chars > MAX_GROQ_PROMPT_CHARS or total_bytes > MAX_GROQ_PROMPT_BYTES
    ):
        provider = "cerebras"
        chosen_model = "qwen-3-235b-a22b-instruct-2507"

    # Return a keyed mapping. ``{chosen_model, provider}`` would be a set
    # literal: no keys and no defined element order in the response.
    return {"chosen_model": chosen_model, "provider": provider}