Spaces:
Paused
Paused
import hmac
import json
import os
from typing import List

from dotenv import load_dotenv
from fastapi import Body, FastAPI, Response, Security
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, model_validator

from conversation.conversation_store import ConversationStore
from rag_langchain import LangChainRAG
# Prompt identifiers handed to the RAG pipeline and recorded with each answer.
prompt_id = "summarize_rag_1"
check_prompt_id = "check_control_challenge_step_back"
rewrite_prompt_id = "first"
# Model name used when a request does not specify one.
default_llm = "gpt-4o 128k"

# Populate os.environ from a .env file first: everything below may read it.
load_dotenv()
# Accepted API keys; a missing API_API_KEY variable fails at import time (KeyError).
api_keys = [os.environ["API_API_KEY"]]

api = FastAPI()
conversation_store = ConversationStore()
# Client key is read from the "Authorization" header; auto_error makes FastAPI
# reject requests that omit it before the handler runs.
api_key_header = APIKeyHeader(name="Authorization", auto_error=True)
class QModel(BaseModel):
    """Request body for the question-answering handler.

    Accepts either a mapping or a string containing a JSON object; a string
    is decoded before the fields are validated.
    """

    q: str  # question to answer
    retrieval_count: int = 10  # forwarded to the RAG config as "retrieve_documents"
    temperature: str = "0.2"  # forwarded to the RAG config as-is (a string)
    llm: str = default_llm  # model name passed to the RAG chain

    @model_validator(mode="before")
    @classmethod
    def validate_to_json(cls, value):
        """Decode a JSON-string payload into a mapping before field validation.

        Args:
            value: Raw input; a string is parsed as JSON, anything else is
                passed through unchanged.

        Returns:
            The decoded object for string input, otherwise ``value``.

        Raises:
            ValueError: ``json.JSONDecodeError`` for a malformed JSON string;
                pydantic surfaces it as a validation error.
        """
        if isinstance(value, str):
            # Hand pydantic the decoded dict instead of a constructed instance,
            # so the fields go through the normal validation path exactly once.
            return json.loads(value)
        return value
class AModel(BaseModel):
    """Response body returned by the question-answering handler.

    Field order is kept as declared because it determines the order of keys
    in the serialized response and in the generated schema.
    """

    q: str  # the question that was asked
    a: str  # the generated answer
    sources: List[str]  # page content of each retrieved document
    oid: str  # id the conversation store assigned to this exchange
class EmoModel(BaseModel):
    """Request body for grading the helpfulness of a stored answer.

    Accepts either a mapping or a string containing a JSON object; a string
    is decoded before the fields are validated.
    """

    qid: str  # id of the stored conversation (the store's ``oid``)
    helpfulness: str  # user's grade, stored under params["user_grading"]

    @model_validator(mode="before")
    @classmethod
    def validate_to_json(cls, value):
        """Decode a JSON-string payload into a mapping before field validation.

        Args:
            value: Raw input; a string is parsed as JSON, anything else is
                passed through unchanged.

        Returns:
            The decoded object for string input, otherwise ``value``.

        Raises:
            ValueError: ``json.JSONDecodeError`` for a malformed JSON string;
                pydantic surfaces it as a validation error.
        """
        if isinstance(value, str):
            # Return the decoded dict so pydantic validates the fields once,
            # through its normal path, rather than receiving a built instance.
            return json.loads(value)
        return value
async def read_root():
    """Respond with the fixed placeholder text ``"Empty"``."""
    # NOTE(review): no route decorator (e.g. ``@api.get(...)``) is visible on
    # this handler in this copy of the file; confirm it is registered on ``api``.
    placeholder = "Empty"
    return placeholder
async def qa(api_key: str = Security(api_key_header), data: QModel = Body(...)):
    """Answer a question with the RAG pipeline and persist the exchange.

    Args:
        api_key: Value of the ``Authorization`` header, checked against the
            accepted API keys.
        data: The question plus retrieval/generation settings.

    Returns:
        An ``AModel`` holding the question, the answer, the text of every
        retrieved source document and the id the store assigned; or an empty
        401 ``Response`` when the API key is not accepted.
    """
    # NOTE(review): no route decorator (e.g. ``@api.post(...)``) is visible on
    # this handler in this copy of the file; confirm it is registered on ``api``.
    if not valid_api_key(api_key):
        return Response(status_code=401)

    rag = LangChainRAG(
        config={
            "retrieve_documents": data.retrieval_count,
            "temperature": data.temperature,
            "prompt_id": prompt_id,
            "check_prompt_id": check_prompt_id,
            "rewrite_prompt_id": rewrite_prompt_id,
        }
    )
    answer, check_result, sources = await rag.rag_chain(data.q, data.llm)

    # Extract the document texts once: they are both stored and returned.
    source_texts = [doc.page_content for doc in sources]

    oid = conversation_store.save_content(
        q=data.q,
        a=answer,
        sources=source_texts,
        params={
            "prompt_id": prompt_id,
            "check_prompt_id": check_prompt_id,
            "rewrite_prompt_id": rewrite_prompt_id,
            "check_result": check_result,
            "temperature": data.temperature,
            "retrieve_document_count": str(data.retrieval_count),
            # Record which model produced the answer, alongside the other settings.
            "llm": data.llm,
        },
    )
    return AModel(a=answer, q=data.q, sources=source_texts, oid=oid)
async def emo(api_key: str = Security(api_key_header), json_body: EmoModel = Body(...)):
    """Store a user's helpfulness grade on a previously saved conversation.

    Args:
        api_key: Value of the ``Authorization`` header, checked against the
            accepted API keys.
        json_body: Id of the stored conversation and the user's grade.

    Returns:
        None once the conversation has been updated; an empty 401 ``Response``
        when the API key is not accepted; an empty 404 ``Response`` when the
        store returns no conversation for the id.
    """
    # NOTE(review): no route decorator (e.g. ``@api.post(...)``) is visible on
    # this handler in this copy of the file; confirm it is registered on ``api``.
    if not valid_api_key(api_key):
        return Response(status_code=401)

    conversation = conversation_store.get(json_body.qid)
    # NOTE(review): assumes ``get`` returns None for an unknown id; if it raises
    # instead, this guard is simply never taken.
    if conversation is None:
        return Response(status_code=404)

    # Copy the params so the grade reaches the store only through the explicit
    # update below, not by mutating the object ``get`` returned.
    new_params = dict(conversation.params)
    new_params["user_grading"] = str(json_body.helpfulness)

    first_turn = conversation.conversation[0]
    conversation_store.update(
        # Attribute access: EmoModel is a pydantic model and is not subscriptable.
        oid=json_body.qid,
        q=first_turn.q,
        a=first_turn.a,
        sources=first_turn.sources,
        params=new_params,
    )
def valid_api_key(api_key: str, keys=None) -> bool:
    """Return True if *api_key* equals one of the accepted API keys.

    Args:
        api_key: Key presented by the client. Any non-``str`` value is rejected.
        keys: Iterable of accepted key strings. Defaults to the module-level
            ``api_keys``.

    Returns:
        True when *api_key* matches an accepted key, otherwise False.
    """
    if keys is None:
        keys = api_keys
    if not isinstance(api_key, str):
        return False

    # Compare as UTF-8 bytes: hmac.compare_digest rejects non-ASCII ``str``
    # arguments, and header values are client-controlled.
    candidate = api_key.encode("utf-8")
    matched = False
    for key in keys:
        # Constant-time comparison, with no short-circuit, so response timing
        # does not reveal how much of a guessed key was correct.
        matched |= hmac.compare_digest(candidate, key.encode("utf-8"))
    return matched