abanm committed on
Commit
3733252
·
verified ·
1 Parent(s): c78218b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -4
app.py CHANGED
@@ -6,6 +6,10 @@ from fastapi import FastAPI, Request, HTTPException
6
  from fastapi.responses import StreamingResponse
7
  from fastapi.middleware.cors import CORSMiddleware
8
 
 
 
 
 
9
  # Configure logging
10
  logging.basicConfig(level=logging.INFO)
11
  logger = logging.getLogger(__name__)
@@ -27,15 +31,25 @@ app.add_middleware(
27
 
28
  # Load your API key from the environment (defaults to "change_me")
29
  API_KEY = os.environ.get("API_KEY", "change_me")
30
-
31
  logger.info(f"API key loaded: {API_KEY}")
32
 
33
  # URL of the running Ollama server (adjust as needed)
34
  OLLAMA_SERVER_URL = "http://localhost:11434/api/generate"
35
 
 
 
 
 
 
 
 
 
 
 
 
36
  @app.post("/api/generate")
37
  async def generate(request: Request):
38
- """Endpoint that generates text based on the prompt."""
39
  try:
40
  # 1. Parse the incoming request
41
  body = await request.json()
@@ -74,7 +88,7 @@ async def generate(request: Request):
74
  response.raise_for_status()
75
  async for chunk in response.aiter_text():
76
  yield chunk
77
- except httpx.RequestError as exc:
78
  logger.exception("Request error while communicating with Ollama")
79
  yield json.dumps({"error": "Unable to communicate with Ollama"})
80
  except httpx.HTTPStatusError as exc:
@@ -87,6 +101,38 @@ async def generate(request: Request):
87
  logger.exception("Unhandled exception in /api/generate")
88
  raise HTTPException(status_code=500, detail="Internal server error")
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  @app.get("/health")
91
  async def health():
92
  """Health check endpoint."""
@@ -97,4 +143,3 @@ if __name__ == "__main__":
97
  import uvicorn
98
  logger.info("Starting server on http://0.0.0.0:7860")
99
  uvicorn.run(app, host="0.0.0.0", port=7860)
100
-
 
6
  from fastapi.responses import StreamingResponse
7
  from fastapi.middleware.cors import CORSMiddleware
8
 
9
# --- New imports for LangChain + langchain_ollama ---
from pydantic import BaseModel
# The langchain_ollama package exports `OllamaLLM` (there is no `Ollama`
# name in it — that class lives in langchain_community). Alias it so the
# rest of this module's references to `Ollama` keep working unchanged.
from langchain_ollama import OllamaLLM as Ollama
13
  # Configure logging
14
  logging.basicConfig(level=logging.INFO)
15
  logger = logging.getLogger(__name__)
 
31
 
32
  # Load your API key from the environment (defaults to "change_me")
33
  API_KEY = os.environ.get("API_KEY", "change_me")
 
34
  logger.info(f"API key loaded: {API_KEY}")
35
 
36
  # URL of the running Ollama server (adjust as needed)
37
  OLLAMA_SERVER_URL = "http://localhost:11434/api/generate"
38
 
39
# --- Initialize a single Ollama LLM instance via langchain_ollama ---
# Created once at module load and shared by all requests; each call to the
# LLM is stateless, so one instance is safe to reuse.
ollama_llm = Ollama(
    model="hf.co/abanm/Dubs-Q8_0-GGUF:latest",  # same model as the raw /api/generate endpoint
    base_url="http://localhost:11434",  # base URL for the Ollama server
    # NOTE(review): the original passed `request_timeout=120`, but that is
    # not a recognized field of langchain_ollama's pydantic LLM model and
    # would be rejected at construction time. If a longer timeout is needed
    # for slow model loads, pass it via `client_kwargs={"timeout": 120}` —
    # TODO confirm against the installed langchain_ollama version.
)

# --------------------------------
# Original endpoint: /api/generate
# --------------------------------
50
  @app.post("/api/generate")
51
  async def generate(request: Request):
52
+ """Endpoint that generates text based on the prompt (direct HTTP call to Ollama)."""
53
  try:
54
  # 1. Parse the incoming request
55
  body = await request.json()
 
88
  response.raise_for_status()
89
  async for chunk in response.aiter_text():
90
  yield chunk
91
+ except httpx.RequestError:
92
  logger.exception("Request error while communicating with Ollama")
93
  yield json.dumps({"error": "Unable to communicate with Ollama"})
94
  except httpx.HTTPStatusError as exc:
 
101
  logger.exception("Unhandled exception in /api/generate")
102
  raise HTTPException(status_code=500, detail="Internal server error")
103
 
104
# ------------------------------
# New endpoint: /api/langchain-generate
# ------------------------------
class LangChainRequest(BaseModel):
    # Prompt text forwarded verbatim to the Ollama model.
    prompt: str


@app.post("/api/langchain-generate")
async def langchain_generate(request: LangChainRequest):
    """
    Generate text using the langchain_ollama LLM wrapper.

    Alternative to /api/generate: instead of a raw HTTP call to the Ollama
    server, this routes through the shared module-level `ollama_llm`.

    Args:
        request: Parsed JSON body containing the `prompt` string.

    Returns:
        dict: ``{"response": <generated text>}``.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    # NOTE(review): unlike /api/generate, this endpoint performs no API-key
    # check — confirm that is intentional before exposing it publicly.
    prompt = request.prompt
    logger.info("LangChain request: %s", prompt)

    try:
        # `ainvoke` replaces the deprecated `llm(prompt)` __call__ style and,
        # being async, does not block the event loop while the model runs
        # (a plain synchronous call here would stall every other request).
        response_text = await ollama_llm.ainvoke(prompt)
        return {"response": response_text}
    except Exception as e:
        logger.exception("Unhandled exception in /api/langchain-generate")
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
136
  @app.get("/health")
137
  async def health():
138
  """Health check endpoint."""
 
143
  import uvicorn
144
  logger.info("Starting server on http://0.0.0.0:7860")
145
  uvicorn.run(app, host="0.0.0.0", port=7860)