Update app.py
Browse files
app.py
CHANGED
|
@@ -6,6 +6,10 @@ from fastapi import FastAPI, Request, HTTPException
|
|
| 6 |
from fastapi.responses import StreamingResponse
|
| 7 |
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
# Configure logging
|
| 10 |
logging.basicConfig(level=logging.INFO)
|
| 11 |
logger = logging.getLogger(__name__)
|
|
@@ -27,15 +31,25 @@ app.add_middleware(
|
|
| 27 |
|
| 28 |
# Load your API key from the environment (defaults to "change_me")
|
| 29 |
API_KEY = os.environ.get("API_KEY", "change_me")
|
| 30 |
-
|
| 31 |
logger.info(f"API key loaded: {API_KEY}")
|
| 32 |
|
| 33 |
# URL of the running Ollama server (adjust as needed)
|
| 34 |
OLLAMA_SERVER_URL = "http://localhost:11434/api/generate"
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
@app.post("/api/generate")
|
| 37 |
async def generate(request: Request):
|
| 38 |
-
"""Endpoint that generates text based on the prompt."""
|
| 39 |
try:
|
| 40 |
# 1. Parse the incoming request
|
| 41 |
body = await request.json()
|
|
@@ -74,7 +88,7 @@ async def generate(request: Request):
|
|
| 74 |
response.raise_for_status()
|
| 75 |
async for chunk in response.aiter_text():
|
| 76 |
yield chunk
|
| 77 |
-
except httpx.RequestError
|
| 78 |
logger.exception("Request error while communicating with Ollama")
|
| 79 |
yield json.dumps({"error": "Unable to communicate with Ollama"})
|
| 80 |
except httpx.HTTPStatusError as exc:
|
|
@@ -87,6 +101,38 @@ async def generate(request: Request):
|
|
| 87 |
logger.exception("Unhandled exception in /api/generate")
|
| 88 |
raise HTTPException(status_code=500, detail="Internal server error")
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
@app.get("/health")
|
| 91 |
async def health():
|
| 92 |
"""Health check endpoint."""
|
|
@@ -97,4 +143,3 @@ if __name__ == "__main__":
|
|
| 97 |
import uvicorn
|
| 98 |
logger.info("Starting server on http://0.0.0.0:7860")
|
| 99 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
| 100 |
-
|
|
|
|
| 6 |
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware

# --- New imports for LangChain + langchain_ollama ---
from pydantic import BaseModel
# The `langchain_ollama` package exports its LLM class as `OllamaLLM`;
# there is no `Ollama` name in that package, so the original
# `from langchain_ollama import Ollama` raised ImportError at startup.
# Aliasing to `Ollama` keeps every existing reference in this module working.
from langchain_ollama import OllamaLLM as Ollama

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
| 31 |
|
| 32 |
# Load your API key from the environment (defaults to "change_me")
API_KEY = os.environ.get("API_KEY", "change_me")

# Never write the secret itself to the logs -- record only enough to
# confirm that a key was loaded (the original line logged the raw key).
logger.info("API key loaded (length=%d)", len(API_KEY))

# URL of the running Ollama server (adjust as needed)
OLLAMA_SERVER_URL = "http://localhost:11434/api/generate"
|
| 38 |
|
| 39 |
+
# --- Initialize a single Ollama instance via langchain_ollama ---
# This is the "LangChain" style interface to the Ollama server; one shared
# instance is reused by every request to /api/langchain-generate.
# NOTE(review): confirm this class accepts `request_timeout` -- some Ollama
# wrappers name this option `timeout`, and an unknown keyword would fail
# model validation at import time.
ollama_llm = Ollama(
    model="hf.co/abanm/Dubs-Q8_0-GGUF:latest",  # same model as before
    base_url="http://localhost:11434",  # base URL for Ollama server
    request_timeout=120  # Increase if model loading is slow
)
|
| 46 |
+
|
| 47 |
+
# --------------------------------
|
| 48 |
+
# Original endpoint: /api/generate
|
| 49 |
+
# --------------------------------
|
| 50 |
@app.post("/api/generate")
|
| 51 |
async def generate(request: Request):
|
| 52 |
+
"""Endpoint that generates text based on the prompt (direct HTTP call to Ollama)."""
|
| 53 |
try:
|
| 54 |
# 1. Parse the incoming request
|
| 55 |
body = await request.json()
|
|
|
|
| 88 |
response.raise_for_status()
|
| 89 |
async for chunk in response.aiter_text():
|
| 90 |
yield chunk
|
| 91 |
+
except httpx.RequestError:
|
| 92 |
logger.exception("Request error while communicating with Ollama")
|
| 93 |
yield json.dumps({"error": "Unable to communicate with Ollama"})
|
| 94 |
except httpx.HTTPStatusError as exc:
|
|
|
|
| 101 |
logger.exception("Unhandled exception in /api/generate")
|
| 102 |
raise HTTPException(status_code=500, detail="Internal server error")
|
| 103 |
|
| 104 |
+
# ------------------------------
# New endpoint: /api/langchain-generate
# ------------------------------
class LangChainRequest(BaseModel):
    """Request body for /api/langchain-generate."""

    # Free-form prompt text forwarded to the LLM.
    prompt: str
|
| 109 |
+
|
| 110 |
+
@app.post("/api/langchain-generate")
async def langchain_generate(request: LangChainRequest):
    """
    Endpoint that uses langchain_ollama to generate text.

    This is an alternative approach that uses the Ollama() class from
    langchain_ollama instead of the hand-rolled HTTP call in /api/generate.

    Returns:
        dict: ``{"response": <generated text>}``.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
    """
    # Check for API key (similar logic as above, or unify the code)
    # ... or just omit it if your environment is already secure
    # If re-using the same approach:
    # (In real code, you'd unify these checks in a shared function)
    #
    # auth_header = ...
    # if token != API_KEY:
    #     raise HTTPException(...)

    prompt = request.prompt
    logger.info(f"LangChain request: {prompt}")

    # Use .invoke(): calling the LLM object directly (``ollama_llm(prompt)``)
    # goes through the deprecated ``__call__`` interface, which newer
    # LangChain releases remove.
    # NOTE(review): this is a synchronous call inside an async endpoint, so it
    # blocks the event loop for the whole generation -- consider
    # ``await ollama_llm.ainvoke(prompt)`` if request concurrency matters.
    try:
        response_text = ollama_llm.invoke(prompt)
        return {"response": response_text}
    except Exception as e:
        logger.exception("Unhandled exception in /api/langchain-generate")
        raise HTTPException(status_code=500, detail=str(e))
|
| 135 |
+
|
| 136 |
@app.get("/health")
|
| 137 |
async def health():
|
| 138 |
"""Health check endpoint."""
|
|
|
|
| 143 |
import uvicorn
|
| 144 |
logger.info("Starting server on http://0.0.0.0:7860")
|
| 145 |
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|