simoncck committed on
Commit
ae29907
·
verified ·
1 Parent(s): 481178d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -14
app.py CHANGED
@@ -20,9 +20,7 @@ def make_llm():
20
  """
21
  provider = os.getenv("LLM_PROVIDER", "gemini").lower()
22
 
23
- # LangChain helper; will raise if GOOGLE_API_KEY missing
24
- # return ChatGoogleGenerativeAI(model="gemini-2.0-flash-exp")
25
-
26
  if provider == "gemini":
27
  return ChatGoogleGenerativeAI(
28
  model=os.getenv("GEMINI_MODEL", "gemini-2.0-flash"),
@@ -31,23 +29,21 @@ def make_llm():
31
  timeout=60,
32
  )
33
 
 
34
  elif provider in {"azure", "azure-openai"}:
35
- return ChatOpenAI(
36
- openai_api_base=os.getenv("AZURE_OPENAI_ENDPOINT"), # NEW NAME
37
- openai_api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-02-15-preview"),
38
- azure_deployment_name=os.getenv("AZURE_OPENAI_DEPLOYMENT", "gpt-35-turbo"),
39
- openai_api_key=os.getenv("AZURE_OPENAI_API_KEY"),
40
- openai_api_type="azure", # required flag
 
41
  temperature=float(os.getenv("AZURE_OPENAI_TEMPERATURE", "0.3")),
42
  max_retries=3,
43
  timeout=60,
44
  )
45
 
46
- else:
47
- raise ValueError(
48
- f"Unsupported LLM_PROVIDER '{provider}'. "
49
- "Use 'gemini' or 'azure-openai'."
50
- )
51
 
52
  @app.post("/run")
53
  async def run_task(t: Task):
 
20
  """
21
  provider = os.getenv("LLM_PROVIDER", "gemini").lower()
22
 
23
+ # --- Google Gemini stays exactly as before ------------------- #
 
 
24
  if provider == "gemini":
25
  return ChatGoogleGenerativeAI(
26
  model=os.getenv("GEMINI_MODEL", "gemini-2.0-flash"),
 
29
  timeout=60,
30
  )
31
 
32
+ # --- Azure OpenAI ------------------------------------------- #
33
  elif provider in {"azure", "azure-openai"}:
34
+ return AzureChatOpenAI( # ✅ correct wrapper
35
+ azure_deployment=os.getenv( # correct kwarg name
36
+ "AZURE_OPENAI_DEPLOYMENT", "gpt-35-turbo"
37
+ ),
38
+ api_version=os.getenv( # ✅ correct kwarg name
39
+ "AZURE_OPENAI_API_VERSION", "2024-05-01-preview"
40
+ ),
41
  temperature=float(os.getenv("AZURE_OPENAI_TEMPERATURE", "0.3")),
42
  max_retries=3,
43
  timeout=60,
44
  )
45
 
46
+ raise ValueError(f"Unsupported LLM_PROVIDER '{provider}'")
 
 
 
 
47
 
48
  @app.post("/run")
49
  async def run_task(t: Task):