# Hugging Face Spaces status header (extraction artifact):
# Spaces: Sleeping
import logging
import threading
import time
from typing import Generator

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from langchain_community.llms import Ollama
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import CallbackManager
# Module-level logging configuration and the FastAPI application instance.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()
# Enable CORS for all origins/methods/headers so browser clients can reach
# the API. NOTE: "*" origins should be restricted in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class QueryWithContext(BaseModel):
    """Request payload for /ask: a question plus the lecture-note context it refers to."""

    question: str
    context: str
    # Name of the Ollama model to run; callers may override the default.
    model: str = "tinyllama"
# Streaming callback that yields tokens
class StreamingGeneratorCallback(BaseCallbackHandler):
    """LangChain callback that buffers generated tokens for consumption
    as a plain-text stream.

    The LLM runs in a producer thread and appends each new token via
    ``on_llm_new_token``; ``stream()`` is consumed from the HTTP response
    thread and drains tokens until the producer clears ``streaming``.
    List append/pop are safe enough here under the GIL for this
    single-producer / single-consumer pattern.
    """

    def __init__(self):
        self.buffer = ""          # unused internally; kept for interface compatibility
        self.queue = []           # tokens: appended by LLM thread, drained by stream()
        self.streaming = True     # producer flips to False when generation finishes

    def on_llm_new_token(self, token: str, **kwargs):
        # Invoked by LangChain for every token the model emits.
        self.queue.append(token)

    def stream(self) -> Generator[str, None, None]:
        # Yield tokens until generation is done AND the queue is drained.
        while self.streaming or self.queue:
            if self.queue:
                yield self.queue.pop(0)
            else:
                # Fix: the original looped without pausing, busy-waiting at
                # 100% CPU whenever the queue was momentarily empty.
                time.sleep(0.01)
@app.get("/")
def root():
    """Health-check endpoint confirming the service is up.

    Fix: the handler was defined but never registered on the app,
    so the route did not exist; ``@app.get("/")`` wires it up.
    """
    return {"message": "FastAPI Ollama is running."}
@app.post("/ask")
def ask(query: QueryWithContext):
    """Answer ``query.question`` using ``query.context`` as lecture notes,
    streaming the model's tokens back as ``text/plain``.

    Fixes: the handler was never registered on the app (missing
    ``@app.post``); the worker thread is now a daemon so a hung
    generation cannot block process shutdown; ``import threading`` is
    hoisted to module level instead of being re-imported per request.

    Raises:
        HTTPException: 500 if setting up the stream fails.
    """
    prompt = f"""You are an expert in quantitative methods. Based on the following lecture notes, answer the user's question.
Lecture notes:
{query.context}
User's question:
{query.question}
"""
    # Callback that captures tokens; the HTTP response drains it.
    stream_callback = StreamingGeneratorCallback()
    callback_manager = CallbackManager([stream_callback])
    llm = Ollama(model=query.model, callback_manager=callback_manager)
    try:
        def run_llm():
            # Run generation in the background; always release the stream,
            # even if the model call raises, so the response can terminate.
            try:
                llm.invoke(prompt)
            finally:
                stream_callback.streaming = False

        threading.Thread(target=run_llm, daemon=True).start()
        return StreamingResponse(stream_callback.stream(), media_type="text/plain")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@app.on_event("startup")
async def startup_event():
    """Log application startup.

    Fix: the hook was defined but never registered; ``@app.on_event``
    attaches it to the application lifecycle.
    """
    logger.info("FastAPI is starting up...")
@app.on_event("shutdown")
async def shutdown_event():
    """Log application shutdown.

    Fix: the hook was defined but never registered; ``@app.on_event``
    attaches it to the application lifecycle.
    """
    logger.info("FastAPI is shutting down.")