"""FastAPI service that streams answers from a local Ollama model via LangChain."""

import logging
import queue
import threading
from typing import Generator

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from langchain_community.llms import Ollama
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import CallbackManager

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # You can restrict this in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class QueryWithContext(BaseModel):
    """Request body for POST /ask."""

    question: str
    context: str
    model: str = "tinyllama"  # default model


# Sentinel pushed onto the token queue to signal end-of-stream to the consumer.
_END_OF_STREAM = object()


class StreamingGeneratorCallback(BaseCallbackHandler):
    """LangChain callback that relays generated tokens to an HTTP response.

    Tokens are produced on the LLM worker thread (``on_llm_new_token``) and
    consumed on the response thread (``stream``). A ``queue.Queue`` provides
    thread-safe hand-off with a *blocking* get, replacing the previous
    busy-wait over a plain list that pinned a CPU core for the whole
    generation.
    """

    def __init__(self):
        # Thread-safe producer/consumer channel for token hand-off.
        self.queue: "queue.Queue[object]" = queue.Queue()

    def on_llm_new_token(self, token: str, **kwargs):
        """Called by LangChain for each new token; enqueue it for streaming."""
        self.queue.put(token)

    def close(self):
        """Mark the stream finished; ``stream()`` terminates after draining."""
        self.queue.put(_END_OF_STREAM)

    def stream(self) -> Generator[str, None, None]:
        """Yield tokens as they arrive, blocking (not spinning) between them."""
        while True:
            item = self.queue.get()  # blocks until a token or the sentinel
            if item is _END_OF_STREAM:
                break
            yield item


@app.get("/")
def root():
    """Liveness check."""
    return {"message": "FastAPI Ollama is running."}


@app.post("/ask")
def ask(query: QueryWithContext):
    """Answer ``query.question`` from ``query.context``, streamed as plain text.

    Generation runs on a background daemon thread; tokens are relayed to the
    client through :class:`StreamingGeneratorCallback` as they are produced.
    Raises ``HTTPException(500)`` if the model/stream cannot be set up.
    """
    prompt = (
        "You are an expert in quantitative methods. "
        "Based on the following lecture notes, answer the user's question.\n"
        f"Lecture notes: {query.context}\n"
        f"User's question: {query.question}\n"
    )

    # Create streaming callback
    stream_callback = StreamingGeneratorCallback()
    callback_manager = CallbackManager([stream_callback])

    try:
        llm = Ollama(model=query.model, callback_manager=callback_manager)

        def run_llm():
            # NOTE: exceptions here cannot reach the HTTP handler (the
            # response has already started) — log them instead of losing
            # them silently, and ALWAYS close the stream so the client's
            # response terminates rather than hanging forever.
            try:
                llm.invoke(prompt)
            except Exception:
                logger.exception("LLM generation failed")
            finally:
                stream_callback.close()

        # Daemon thread: a hung generation must not block interpreter shutdown.
        threading.Thread(target=run_llm, daemon=True).start()
        return StreamingResponse(stream_callback.stream(), media_type="text/plain")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.on_event("startup")  # NOTE: on_event is deprecated in newer FastAPI (lifespan); kept for compatibility
async def startup_event():
    logger.info("FastAPI is starting up...")


@app.on_event("shutdown")
async def shutdown_event():
    logger.info("FastAPI is shutting down.")