Spaces:
Paused
Paused
Commit ·
1bc3f18
1
Parent(s): 9676e57
1st
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .env +115 -0
- .gitignore +17 -0
- Dockerfile +36 -0
- README.md +21 -4
- celery_app.py +30 -0
- config.py +64 -0
- docker-entrypoint.sh +14 -0
- generation/AssistantRagGenerator.py +201 -0
- generation/ExamAnswer.py +314 -0
- generation/ExamRagGenerator.py +460 -0
- generation/__init__.py +0 -0
- generation/answer_models.py +51 -0
- generation/parsing_utils.py +51 -0
- generation/prompts.py +250 -0
- indexing/indexingController.py +111 -0
- ingestion/chunkers/__init__.py +0 -0
- ingestion/chunkers/fixed_chunker.py +10 -0
- ingestion/chunkers/recursive_chunker.py +9 -0
- ingestion/loaders/File_loader.py +57 -0
- ingestion/loaders/__init__.py +1 -0
- ingestion/loaders/docx_loader.py +89 -0
- ingestion/loaders/md_loader.py +48 -0
- ingestion/loaders/normalization.py +35 -0
- ingestion/loaders/pdf_loader.py +66 -0
- ingestion/loaders/txt_loader.py +38 -0
- ingestion/pdf_outline.py +62 -0
- main.py +20 -0
- requirements.txt +19 -0
- routes/__init__.py +0 -0
- routes/assisstant_rag.py +165 -0
- routes/base.py +45 -0
- routes/exam_grading_router.py +122 -0
- routes/exam_router.py +15 -0
- routes/schemas/Exam_Models.py +180 -0
- routes/schemas/Requests_Models.py +24 -0
- routes/schemas/__init__.py +0 -0
- stores/llm/LLMEnums.py +28 -0
- stores/llm/LLMInterface.py +24 -0
- stores/llm/LLMProviderFactory.py +93 -0
- stores/llm/__init__.py +0 -0
- stores/llm/providers/CohereProvider.py +395 -0
- stores/llm/providers/DeepSeekProvider.py +126 -0
- stores/llm/providers/GeminiProvider.py +305 -0
- stores/llm/providers/GroqProvider.py +133 -0
- stores/llm/providers/HuggingFaceProvider.py +214 -0
- stores/llm/providers/MistralProvider.py +208 -0
- stores/llm/providers/OllamaProvider.py +292 -0
- stores/llm/providers/OpenAIProvider.py +102 -0
- stores/llm/providers/OpenRouterProvider.py +179 -0
- stores/llm/providers/__init__.py +0 -0
.env
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
APP_NAME="IntegraRAG"
|
| 2 |
+
DEBUG=False
|
| 3 |
+
CustomLoaders=False
|
| 4 |
+
|
| 5 |
+
# ---------- QDRANT ---------- Choose One
|
| 6 |
+
# QDRANT_TYPE="local"
|
| 7 |
+
# QDRANT_DOCKER_URL=""
|
| 8 |
+
# QDRANT_API_KEY=""
|
| 9 |
+
|
| 10 |
+
# QDRANT_TYPE="docker"
|
| 11 |
+
# QDRANT_DOCKER_URL="http://localhost:6333/"
|
| 12 |
+
# QDRANT_API_KEY=""
|
| 13 |
+
|
| 14 |
+
QDRANT_TYPE="cloud"
|
| 15 |
+
QDRANT_DOCKER_URL="https://d7e287d8-903d-436c-854c-03cbef9e4edb.us-east4-0.gcp.cloud.qdrant.io"
|
| 16 |
+
QDRANT_API_KEY="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.NRbT0QPl7isuBKvdtganh89xa2DeMgKXZ3gSJngexQg"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ---------- REDIS ----------
|
| 20 |
+
REDIS_HOST="rediss://default:gQAAAAAAAS-BAAIncDFiM2E3OGQ1MmU5Zjk0OGM5ODU2ZmMzYzc4NjZjYzdjMHAxNzc2OTc@steady-clam-77697.upstash.io"
|
| 21 |
+
REDIS_PORT=6379
|
| 22 |
+
|
| 23 |
+
# ---------- WEBHOOKS ----------
|
| 24 |
+
CALLBACK_URL="https://webhooksite.net/c93aac48-5237-4078-9511-14d778acba2f"
|
| 25 |
+
GRADE_WEBHOOK_URL="https://webhooksite.net/c93aac48-5237-4078-9511-14d778acba2f"
|
| 26 |
+
|
| 27 |
+
# ---------- BACKENDS ---------- Choose One
|
| 28 |
+
#generation
|
| 29 |
+
# OLLAMA | COHERE | MISTRAL | GEMINI | HUGGINGFACE | GROQ | OPENROUTER | DEEPSEEK |
|
| 30 |
+
#embedding
|
| 31 |
+
# OLLAMA | COHERE | MISTRAL | GEMINI | HUGGINGFACE
|
| 32 |
+
|
| 33 |
+
# ---------- OLLAMA ----------
|
| 34 |
+
OLLAMA_URL="http://localhost:11434"
|
| 35 |
+
# OLLAMA_API_KEY="getAone"
|
| 36 |
+
# GENERATION_BACKEND="OLLAMA"
|
| 37 |
+
# EMBEDDING_BACKEND="OLLAMA"
|
| 38 |
+
# GENERATION_MODEL_ID="deepseek-v3.1:671b-cloud"
|
| 39 |
+
# EMBEDDING_MODEL_ID="embeddinggemma:latest"
|
| 40 |
+
# EMBEDDING_MODEL_SIZE=768
|
| 41 |
+
# QDRANT_COLLECTION="768_docs"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ---------- COHERE ----------
|
| 45 |
+
COHERE_API_KEY="getAone"
|
| 46 |
+
# GENERATION_BACKEND="COHERE"
|
| 47 |
+
# EMBEDDING_BACKEND="COHERE"
|
| 48 |
+
# GENERATION_MODEL_ID="command-a-03-2025"
|
| 49 |
+
# EMBEDDING_MODEL_ID="embed-multilingual-v3.0"
|
| 50 |
+
# EMBEDDING_MODEL_SIZE=1024
|
| 51 |
+
# QDRANT_COLLECTION="1024_docs"
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# ---------- MISTRAL ----------
|
| 55 |
+
MISTRAL_API_KEY="getAone"
|
| 56 |
+
# GENERATION_BACKEND="MISTRAL"
|
| 57 |
+
# EMBEDDING_BACKEND="MISTRAL"
|
| 58 |
+
# GENERATION_MODEL_ID="mistral-small-2603"
|
| 59 |
+
# EMBEDDING_MODEL_ID="mistral-embed-2312"
|
| 60 |
+
# EMBEDDING_MODEL_SIZE=1024
|
| 61 |
+
# QDRANT_COLLECTION="1024_docs"
|
| 62 |
+
|
| 63 |
+
# ---------- GEMINI ----------
|
| 64 |
+
GEMINI_API_KEY="getAone"
|
| 65 |
+
GENERATION_BACKEND="GEMINI"
|
| 66 |
+
EMBEDDING_BACKEND="GEMINI"
|
| 67 |
+
GENERATION_MODEL_ID="gemini-2.5-flash"
|
| 68 |
+
EMBEDDING_MODEL_ID="gemini-embedding-001"
|
| 69 |
+
EMBEDDING_MODEL_SIZE=768
|
| 70 |
+
QDRANT_COLLECTION="768_docs"
|
| 71 |
+
|
| 72 |
+
# ---------- HUGGING FACE ----------
|
| 73 |
+
HF_API_KEY="getAone"
|
| 74 |
+
# GENERATION_BACKEND="HUGGINGFACE"
|
| 75 |
+
# EMBEDDING_BACKEND="HUGGINGFACE"
|
| 76 |
+
# GENERATION_MODEL_ID="Qwen/Qwen2.5-72B-Instruct"
|
| 77 |
+
# EMBEDDING_MODEL_ID="google/embeddinggemma-300m"
|
| 78 |
+
# EMBEDDING_MODEL_SIZE=768
|
| 79 |
+
# QDRANT_COLLECTION="768_docs"
|
| 80 |
+
|
| 81 |
+
# ---------- DEEPSEEK ---------- paid
|
| 82 |
+
DEEPSEEK_API_KEY="getAone"
|
| 83 |
+
# GENERATION_BACKEND="DEEPSEEK"
|
| 84 |
+
# EMBEDDING_BACKEND="COHERE"
|
| 85 |
+
# GENERATION_MODEL_ID="deepseek-chat"
|
| 86 |
+
# EMBEDDING_MODEL_ID="embed-multilingual-v3.0"
|
| 87 |
+
# EMBEDDING_MODEL_SIZE=1024
|
| 88 |
+
# QDRANT_COLLECTION="1024_docs"
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# ---------- OPENAI ---------- paid
|
| 92 |
+
OPENAI_API_KEY=""
|
| 93 |
+
OPENAI_API_URL=""
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# ---------- GROQ ----------not complete
|
| 97 |
+
GROQ_API_KEY=""
|
| 98 |
+
|
| 99 |
+
# ---------- OPENROUTER ----------not complete
|
| 100 |
+
OPENROUTER_API_KEY=""
|
| 101 |
+
OPENROUTER_SITE_URL="http://localhost"
|
| 102 |
+
OPENROUTER_APP_NAME="IntegraRAG"
|
| 103 |
+
OPENROUTER_SEARCH_MODEL="perplexity/sonar-online"
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# ---------- DEFAULTS ----------
|
| 108 |
+
INPUT_DAFAULT_MAX_CHARACTERS=2048
|
| 109 |
+
GENERATION_DAFAULT_MAX_TOKENS=1200
|
| 110 |
+
GENERATION_DAFAULT_TEMPERATURE=0.3
|
| 111 |
+
|
| 112 |
+
# ---------- CHUNKING ----------
|
| 113 |
+
CHUNK_SIZE=700
|
| 114 |
+
CHUNK_OVERLAP=150
|
| 115 |
+
CHUNK_METHOD="recursive"
|
.gitignore
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python virtual environment
|
| 2 |
+
venv/
|
| 3 |
+
.env/
|
| 4 |
+
__pycache__/
|
| 5 |
+
*.pyc
|
| 6 |
+
*.pyo
|
| 7 |
+
*.pyd
|
| 8 |
+
|
| 9 |
+
Code_Backups.txt
|
| 10 |
+
data/
|
| 11 |
+
# VSCode
|
| 12 |
+
.vscode/
|
| 13 |
+
.vs/
|
| 14 |
+
|
| 15 |
+
# OS
|
| 16 |
+
.DS_Store
|
| 17 |
+
Thumbs.db
|
Dockerfile
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ─────────────────────────────────────────────
|
| 2 |
+
# IntegraRAG — Production Dockerfile
|
| 3 |
+
# Services bundled: FastAPI + Celery worker
|
| 4 |
+
# External deps: Redis, Qdrant (cloud/managed)
|
| 5 |
+
# ─────────────────────────────────────────────
|
| 6 |
+
FROM python:3.11-slim
|
| 7 |
+
|
| 8 |
+
# System deps
|
| 9 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 10 |
+
build-essential \
|
| 11 |
+
curl \
|
| 12 |
+
libmagic1 \
|
| 13 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 14 |
+
|
| 15 |
+
WORKDIR /app
|
| 16 |
+
|
| 17 |
+
# Install Python deps first (layer cache)
|
| 18 |
+
COPY requirements.txt .
|
| 19 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 20 |
+
|
| 21 |
+
# Copy application source
|
| 22 |
+
COPY . .
|
| 23 |
+
|
| 24 |
+
# ── Runtime env defaults (override via HF Secrets or docker run -e) ──
|
| 25 |
+
ENV PORT=7860 \
|
| 26 |
+
PYTHONUNBUFFERED=1 \
|
| 27 |
+
PYTHONDONTWRITEBYTECODE=1
|
| 28 |
+
|
| 29 |
+
# Hugging Face Spaces exposes port 7860
|
| 30 |
+
EXPOSE 7860
|
| 31 |
+
|
| 32 |
+
# Entrypoint: start Celery worker in background, then FastAPI
|
| 33 |
+
COPY docker-entrypoint.sh /docker-entrypoint.sh
|
| 34 |
+
RUN chmod +x /docker-entrypoint.sh
|
| 35 |
+
|
| 36 |
+
ENTRYPOINT ["/docker-entrypoint.sh"]
|
README.md
CHANGED
|
@@ -1,10 +1,27 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
colorTo: blue
|
| 6 |
sdk: docker
|
|
|
|
| 7 |
pinned: false
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: IntegraRAG API
|
| 3 |
+
emoji: 🧠
|
| 4 |
+
colorFrom: indigo
|
| 5 |
colorTo: blue
|
| 6 |
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
pinned: false
|
| 9 |
---
|
| 10 |
|
| 11 |
+
# IntegraRAG — RAG-Powered Exam & Assistant API
|
| 12 |
+
|
| 13 |
+
FastAPI backend with Celery workers for document-based Q&A, exam generation, and AI-graded exam submissions.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
`conda create -n RAG_API python==3.11`
|
| 17 |
+
`conda activate RAG_API`
|
| 18 |
+
`pip install -r requirements.txt`
|
| 19 |
+
|
| 20 |
+
`docker run -p 6333:6333 qdrant/qdrant`
|
| 21 |
+
`docker run -d -p 6379:6379 redis:7`
|
| 22 |
+
|
| 23 |
+
# View The .env
|
| 24 |
+
|
| 25 |
+
`celery -A celery_app.celery_app worker -P threads --loglevel=info`
|
| 26 |
+
`uvicorn main:app --host 0.0.0.0 --port 8030 --reload`
|
| 27 |
+
`uvicorn webhook:app --reload`
|
celery_app.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# celery_app.py
|
| 2 |
+
from celery import Celery
|
| 3 |
+
import redis
|
| 4 |
+
from config import get_settings
|
| 5 |
+
|
| 6 |
+
celery_app = Celery(
|
| 7 |
+
"assistant_worker",
|
| 8 |
+
broker=f"{get_settings().REDIS_HOST}:{get_settings().REDIS_PORT}/0",
|
| 9 |
+
backend=f"{get_settings().REDIS_HOST}:{get_settings().REDIS_PORT}/1",
|
| 10 |
+
include=['generation.ExamAnswer']
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
celery_app.conf.update(
|
| 14 |
+
task_serializer="json",
|
| 15 |
+
accept_content=["json"],
|
| 16 |
+
result_serializer="json",
|
| 17 |
+
task_track_started=True,
|
| 18 |
+
task_time_limit=60*60,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
import worker.tasks
|
| 22 |
+
from generation.ExamAnswer import grade_exam_task
|
| 23 |
+
def clear_redis_backend():
|
| 24 |
+
r = redis.Redis(host=get_settings().REDIS_HOST, port=get_settings().REDIS_PORT, db=1)
|
| 25 |
+
r.flushdb()
|
| 26 |
+
print("Redis result backend cleared!")
|
| 27 |
+
|
| 28 |
+
@celery_app.on_after_configure.connect
|
| 29 |
+
def setup(sender, **kwargs):
|
| 30 |
+
clear_redis_backend()
|
config.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic_settings import BaseSettings
|
| 2 |
+
from functools import lru_cache
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Settings(BaseSettings):
|
| 6 |
+
DEBUG: bool = False
|
| 7 |
+
APP_NAME: str
|
| 8 |
+
QDRANT_COLLECTION: str = "docs"
|
| 9 |
+
|
| 10 |
+
CustomLoaders: bool = None
|
| 11 |
+
QDRANT_TYPE: str = "docker"
|
| 12 |
+
QDRANT_DOCKER_URL: str = "http://localhost:6333"
|
| 13 |
+
QDRANT_API_KEY: str = None
|
| 14 |
+
CHUNK_SIZE: int = 1000
|
| 15 |
+
CHUNK_OVERLAP: int = None
|
| 16 |
+
CHUNK_METHOD: str = None
|
| 17 |
+
GRADE_WEBHOOK_URL: str = None
|
| 18 |
+
REDIS_HOST: str = "localhost"
|
| 19 |
+
REDIS_PORT: int = 6379
|
| 20 |
+
CALLBACK_URL: str = None
|
| 21 |
+
|
| 22 |
+
# ---------- BACKENDS ----------
|
| 23 |
+
GENERATION_BACKEND: str = "OLLAMA"
|
| 24 |
+
EMBEDDING_BACKEND: str = "OLLAMA"
|
| 25 |
+
|
| 26 |
+
# ---------- API KEYS ----------
|
| 27 |
+
OPENAI_API_KEY: str = None
|
| 28 |
+
OPENAI_API_URL: str = None
|
| 29 |
+
|
| 30 |
+
COHERE_API_KEY: str = None
|
| 31 |
+
|
| 32 |
+
OLLAMA_URL: str = "http://localhost:11434"
|
| 33 |
+
OLLAMA_API_KEY: str = None
|
| 34 |
+
|
| 35 |
+
MISTRAL_API_KEY: str = None
|
| 36 |
+
|
| 37 |
+
GROQ_API_KEY: str = None
|
| 38 |
+
|
| 39 |
+
OPENROUTER_API_KEY: str = None
|
| 40 |
+
OPENROUTER_SITE_URL: str = "http://localhost" # forwarded as HTTP-Referer
|
| 41 |
+
OPENROUTER_APP_NAME: str = "IntegraRAG" # forwarded as X-Title
|
| 42 |
+
OPENROUTER_SEARCH_MODEL: str = "perplexity/sonar-online"
|
| 43 |
+
|
| 44 |
+
HF_API_KEY: str = None
|
| 45 |
+
|
| 46 |
+
DEEPSEEK_API_KEY: str = None
|
| 47 |
+
|
| 48 |
+
GEMINI_API_KEY: str = None
|
| 49 |
+
|
| 50 |
+
# ---------- MODELS ----------
|
| 51 |
+
GENERATION_MODEL_ID: str = "deepseek-v3.1:671b-cloud"
|
| 52 |
+
EMBEDDING_MODEL_ID: str = "embeddinggemma:latest"
|
| 53 |
+
EMBEDDING_MODEL_SIZE: int = 768
|
| 54 |
+
INPUT_DAFAULT_MAX_CHARACTERS: int = None
|
| 55 |
+
GENERATION_DAFAULT_MAX_TOKENS: int = None
|
| 56 |
+
GENERATION_DAFAULT_TEMPERATURE: float = None
|
| 57 |
+
|
| 58 |
+
class Config:
|
| 59 |
+
env_file = ".env"
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@lru_cache
|
| 63 |
+
def get_settings():
|
| 64 |
+
return Settings()
|
docker-entrypoint.sh
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -e
|
| 3 |
+
|
| 4 |
+
echo "==> Starting Celery worker in background..."
|
| 5 |
+
celery -A celery_app.celery_app worker \
|
| 6 |
+
-P threads \
|
| 7 |
+
--loglevel=info \
|
| 8 |
+
--concurrency=4 &
|
| 9 |
+
|
| 10 |
+
echo "==> Starting FastAPI (uvicorn) on port ${PORT:-7860}..."
|
| 11 |
+
exec uvicorn main:app \
|
| 12 |
+
--host 0.0.0.0 \
|
| 13 |
+
--port "${PORT:-7860}" \
|
| 14 |
+
--workers 1
|
generation/AssistantRagGenerator.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
from pydantic import Field
|
| 3 |
+
from langchain_core.language_models import LLM
|
| 4 |
+
from langchain_core.runnables import RunnableBranch, RunnableLambda, RunnableParallel
|
| 5 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 6 |
+
from langchain_core.prompts import PromptTemplate
|
| 7 |
+
from stores.llm.LLMProviderFactory import LLMProviderFactory
|
| 8 |
+
from config import get_settings
|
| 9 |
+
|
| 10 |
+
class ProviderLLMWrapper(LLM):
|
| 11 |
+
provider: Any = Field(..., description="The wrapped LLM provider")
|
| 12 |
+
def _call(self, prompt: str, stop=None) -> str:
|
| 13 |
+
# Calls the underlying model and ensures a string is returned
|
| 14 |
+
result = self.provider.generate_text(prompt)
|
| 15 |
+
if result is None:
|
| 16 |
+
raise ValueError("LLM provider returned None (likely due to timeout or error)")
|
| 17 |
+
if isinstance(result, dict):
|
| 18 |
+
response = result.get("response")
|
| 19 |
+
if response is None:
|
| 20 |
+
raise ValueError(f"LLM provider returned dict without 'response' key: {result.keys()}")
|
| 21 |
+
return response
|
| 22 |
+
if isinstance(result, str):
|
| 23 |
+
return result
|
| 24 |
+
raise ValueError(f"Unexpected LLM response type: {type(result).__name__}")
|
| 25 |
+
@property
|
| 26 |
+
def _llm_type(self):
|
| 27 |
+
return "custom-provider"
|
| 28 |
+
|
| 29 |
+
def get_num_tokens(self, text: str) -> int:
|
| 30 |
+
return len(text.split())
|
| 31 |
+
|
| 32 |
+
class AssistantRagGen:
|
| 33 |
+
def __init__(self):
|
| 34 |
+
config = get_settings()
|
| 35 |
+
self.factory = LLMProviderFactory(config)
|
| 36 |
+
self.generator = self.factory.create(config.GENERATION_BACKEND)
|
| 37 |
+
self.generator.set_generation_model(config.GENERATION_MODEL_ID)
|
| 38 |
+
self.llm = ProviderLLMWrapper(provider=self.generator)
|
| 39 |
+
self.valid_routes = {"user_info", "site_query", "pdf_query"}
|
| 40 |
+
|
| 41 |
+
def build_router_prompt(self, user_prompt: str) -> str:
|
| 42 |
+
return f"""You are a query routing classifier. Your sole job is to categorize a user's question into exactly one routing category.
|
| 43 |
+
|
| 44 |
+
## Categories
|
| 45 |
+
|
| 46 |
+
| Category | Routes questions about... |
|
| 47 |
+
|--------------|-------------------------------------------------------------------------------------------|
|
| 48 |
+
| `user_info` | Personal profile, enrolled courses, username, role, learning progress, achievements |
|
| 49 |
+
| `site_query` | Platform features, website navigation, rules, policies, FAQs, general platform knowledge |
|
| 50 |
+
| `pdf_query` | Document content, uploaded files, PDF search, lesson materials, reading resources |
|
| 51 |
+
|
| 52 |
+
## Examples
|
| 53 |
+
|
| 54 |
+
user_info → "What courses am I enrolled in?"
|
| 55 |
+
user_info → "What is my current progress in the Python course?"
|
| 56 |
+
site_query → "How do I reset my password?"
|
| 57 |
+
site_query → "What are the platform's refund policies?"
|
| 58 |
+
pdf_query → "What does the document say about recursion?"
|
| 59 |
+
pdf_query → "Find me the section on neural networks in the materials"
|
| 60 |
+
|
| 61 |
+
## Decision Rules
|
| 62 |
+
|
| 63 |
+
1. If the question involves the **current user's personal data** → `user_info`
|
| 64 |
+
2. If the question is about **how the platform works** → `site_query`
|
| 65 |
+
3. If the question requires **reading or searching a document** → `pdf_query`
|
| 66 |
+
4. When ambiguous, prefer `pdf_query` over `site_query`, and `user_info` over both.
|
| 67 |
+
|
| 68 |
+
## Output Format
|
| 69 |
+
|
| 70 |
+
Respond with a single lowercase word. No punctuation. No explanation. No whitespace.
|
| 71 |
+
|
| 72 |
+
Valid outputs: user_info | site_query | pdf_query
|
| 73 |
+
|
| 74 |
+
Question: {user_prompt}
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
def build_unified_prompt(self, context: str, question: str, conversation_history: str = "", User_Info: str = "") -> str:
|
| 78 |
+
return f"""
|
| 79 |
+
You are a helpful university assistant.
|
| 80 |
+
|
| 81 |
+
Rules:
|
| 82 |
+
- Use the provided context FIRST.
|
| 83 |
+
- Use conversation history to understand follow-up questions.
|
| 84 |
+
- If the question is about the user, use the User_Info and enrolled_courses.
|
| 85 |
+
- If the answer is not in the context, say:
|
| 86 |
+
"Not found in the provided materials."
|
| 87 |
+
Then add:
|
| 88 |
+
"From my own information:" and answer briefly.
|
| 89 |
+
- Be concise and clear.
|
| 90 |
+
|
| 91 |
+
Conversation History:
|
| 92 |
+
{conversation_history if conversation_history else "None"}
|
| 93 |
+
|
| 94 |
+
User Info:
|
| 95 |
+
{User_Info if User_Info else "None"}
|
| 96 |
+
|
| 97 |
+
Context:
|
| 98 |
+
{context}
|
| 99 |
+
|
| 100 |
+
Current Question:
|
| 101 |
+
{question}
|
| 102 |
+
|
| 103 |
+
Answer:
|
| 104 |
+
"""
|
| 105 |
+
|
| 106 |
+
def build_user_info_prompt(self, question: str, conversation_history: str = "", User_Info: str = "") -> str:
|
| 107 |
+
return f"""
|
| 108 |
+
You are a university assistant handling a user account inquiry.
|
| 109 |
+
Use the provided User Info and Enrolled Courses to answer the question accurately.
|
| 110 |
+
|
| 111 |
+
Conversation History:
|
| 112 |
+
{conversation_history if conversation_history else "None"}
|
| 113 |
+
|
| 114 |
+
User Info:
|
| 115 |
+
{User_Info if User_Info else "None"}
|
| 116 |
+
|
| 117 |
+
Current Question:
|
| 118 |
+
{question}
|
| 119 |
+
|
| 120 |
+
Answer:
|
| 121 |
+
"""
|
| 122 |
+
|
| 123 |
+
def build_site_query_prompt(self, question: str,context:str="", conversation_history: str = "") -> str:
|
| 124 |
+
return f"""
|
| 125 |
+
You are a university assistant handling a platform or site-related question.
|
| 126 |
+
Provide clear instructions, rules, or general information about how the university platform works.
|
| 127 |
+
|
| 128 |
+
Conversation History:
|
| 129 |
+
{conversation_history if conversation_history else "None"}
|
| 130 |
+
|
| 131 |
+
Current Question:
|
| 132 |
+
{question}
|
| 133 |
+
|
| 134 |
+
Site Context:
|
| 135 |
+
{context if context else "None"}
|
| 136 |
+
|
| 137 |
+
Answer:
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
def robust_router(self, input_data: dict) -> str:
|
| 141 |
+
question = input_data["question"]
|
| 142 |
+
attempts = 0
|
| 143 |
+
while attempts < 3:
|
| 144 |
+
prompt = self.build_router_prompt(question)
|
| 145 |
+
route = self.llm.invoke(prompt).strip().lower()
|
| 146 |
+
|
| 147 |
+
if route in self.valid_routes:
|
| 148 |
+
return route
|
| 149 |
+
attempts += 1
|
| 150 |
+
return "pdf_query"
|
| 151 |
+
|
| 152 |
+
def get_chain(self):
|
| 153 |
+
router_node = RunnableLambda(self.robust_router)
|
| 154 |
+
|
| 155 |
+
user_info_chain = RunnableLambda(lambda x: self.llm.invoke(
|
| 156 |
+
self.build_user_info_prompt(
|
| 157 |
+
question=x["question"],
|
| 158 |
+
conversation_history=x.get("conversation_history", ""),
|
| 159 |
+
User_Info=x.get("User_Info", ""),
|
| 160 |
+
)
|
| 161 |
+
))
|
| 162 |
+
|
| 163 |
+
site_query_chain = RunnableLambda(lambda x: self.llm.invoke(
|
| 164 |
+
self.build_site_query_prompt(
|
| 165 |
+
question=x["question"],
|
| 166 |
+
context=x.get("context", ""),
|
| 167 |
+
conversation_history=x.get("conversation_history", "")
|
| 168 |
+
)
|
| 169 |
+
))
|
| 170 |
+
|
| 171 |
+
pdf_query_chain = RunnableLambda(lambda x: self.llm.invoke(
|
| 172 |
+
self.build_unified_prompt(
|
| 173 |
+
context=x.get("context", "No context provided."),
|
| 174 |
+
question=x["question"],
|
| 175 |
+
conversation_history=x.get("conversation_history", ""),
|
| 176 |
+
User_Info=x.get("User_Info", ""),
|
| 177 |
+
)
|
| 178 |
+
))
|
| 179 |
+
|
| 180 |
+
branching_logic = RunnableBranch(
|
| 181 |
+
(lambda x: x["topic"] == "user_info", user_info_chain),
|
| 182 |
+
(lambda x: x["topic"] == "site_query", site_query_chain),
|
| 183 |
+
pdf_query_chain
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
full_chain = (
|
| 187 |
+
RunnableParallel({
|
| 188 |
+
"topic": router_node,
|
| 189 |
+
# Pass all incoming variables straight through to the branches
|
| 190 |
+
"question": lambda x: x["question"],
|
| 191 |
+
"context": lambda x: x.get("context", ""),
|
| 192 |
+
"conversation_history": lambda x: x.get("conversation_history", ""),
|
| 193 |
+
"User_Info": lambda x: x.get("User_Info", ""),
|
| 194 |
+
"enrolled_courses": lambda x: x.get("enrolled_courses", "")
|
| 195 |
+
})
|
| 196 |
+
| branching_logic
|
| 197 |
+
| StrOutputParser()
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
return full_chain
|
| 201 |
+
|
generation/ExamAnswer.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from datetime import datetime
|
| 3 |
+
from typing import List, Dict, Any
|
| 4 |
+
from celery import shared_task
|
| 5 |
+
import json
|
| 6 |
+
import re
|
| 7 |
+
import httpx
|
| 8 |
+
|
| 9 |
+
from generation.answer_models import (ExamSubmission,ExamResult,StudentAnswer,GradedAnswer,QuestionType)
|
| 10 |
+
from indexing.indexingController import IndexingController
|
| 11 |
+
from stores.llm.LLMProviderFactory import LLMProviderFactory
|
| 12 |
+
from config import get_settings
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def calculate_grade(percentage: float) -> str:
|
| 16 |
+
if percentage >= 90:
|
| 17 |
+
return "A"
|
| 18 |
+
elif percentage >= 80:
|
| 19 |
+
return "B"
|
| 20 |
+
elif percentage >= 70:
|
| 21 |
+
return "C"
|
| 22 |
+
elif percentage >= 60:
|
| 23 |
+
return "D"
|
| 24 |
+
else:
|
| 25 |
+
return "F"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
logger = logging.getLogger(__name__)
|
| 29 |
+
|
| 30 |
+
class ExamGradingService:
|
| 31 |
+
def __init__(self, use_ai_for_essays: bool = True):
|
| 32 |
+
self.use_ai_for_essays = use_ai_for_essays
|
| 33 |
+
|
| 34 |
+
config = get_settings()
|
| 35 |
+
|
| 36 |
+
factory = LLMProviderFactory(config)
|
| 37 |
+
provider = factory.create(config.GENERATION_BACKEND)
|
| 38 |
+
provider.set_generation_model(config.GENERATION_MODEL_ID)
|
| 39 |
+
self.llm = provider
|
| 40 |
+
|
| 41 |
+
self.semantic_threshold = 0.65
|
| 42 |
+
self.high_confidence = 0.85
|
| 43 |
+
|
| 44 |
+
def grade_submission(self, submission: ExamSubmission) -> ExamResult:
|
| 45 |
+
graded_answers: List[GradedAnswer] = []
|
| 46 |
+
total_score = 0
|
| 47 |
+
max_total_score = 0
|
| 48 |
+
|
| 49 |
+
for ans in submission.answers:
|
| 50 |
+
correct_answer = None
|
| 51 |
+
if ans.metadata:
|
| 52 |
+
correct_answer = ans.metadata.get("correct_answer")
|
| 53 |
+
|
| 54 |
+
graded = self.grade_answer(ans, correct_answer,submission.course_id)
|
| 55 |
+
graded_answers.append(graded)
|
| 56 |
+
total_score += graded.score
|
| 57 |
+
max_total_score += graded.max_score
|
| 58 |
+
|
| 59 |
+
percentage = (total_score / max_total_score) * 100 if max_total_score else 0
|
| 60 |
+
grade = calculate_grade(percentage)
|
| 61 |
+
|
| 62 |
+
return ExamResult(
|
| 63 |
+
exam_id=submission.exam_id,
|
| 64 |
+
student_id=submission.student_id,
|
| 65 |
+
student_name=submission.student_name,
|
| 66 |
+
graded_answers=graded_answers,
|
| 67 |
+
total_score=total_score,
|
| 68 |
+
max_total_score=max_total_score,
|
| 69 |
+
percentage=percentage,
|
| 70 |
+
grade=grade,
|
| 71 |
+
feedback_summary="RAG based grading using LLM evaluation",
|
| 72 |
+
submission_time=submission.submission_time,
|
| 73 |
+
graded_time=datetime.utcnow().isoformat()
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
def grade_answer(self, answer: StudentAnswer, correct_answer: Any, course) -> GradedAnswer:
|
| 77 |
+
if answer.question_type in [QuestionType.MULTIPLE_CHOICE,QuestionType.TRUE_FALSE]:
|
| 78 |
+
student_str = str(answer.student_response).strip().lower()
|
| 79 |
+
if answer.question_type == QuestionType.TRUE_FALSE:
|
| 80 |
+
if isinstance(correct_answer, bool):
|
| 81 |
+
correct_bool = correct_answer
|
| 82 |
+
elif isinstance(correct_answer, str):
|
| 83 |
+
correct_bool = correct_answer.lower() in ['true', 't', '1', 'yes', 'True']
|
| 84 |
+
else:
|
| 85 |
+
correct_bool = bool(correct_answer)
|
| 86 |
+
student_bool = student_str in ['true', 't', '1', 'yes']
|
| 87 |
+
is_correct = student_bool == correct_bool
|
| 88 |
+
score = answer.max_score if is_correct else 0
|
| 89 |
+
feedback = "Exact match grading"
|
| 90 |
+
else: # multiple_choice
|
| 91 |
+
correct_str = str(correct_answer).strip().lower() if correct_answer else ""
|
| 92 |
+
is_correct = student_str == correct_str
|
| 93 |
+
score = answer.max_score if is_correct else 0
|
| 94 |
+
feedback = "Exact match grading"
|
| 95 |
+
else:
|
| 96 |
+
if self.use_ai_for_essays and correct_answer:
|
| 97 |
+
score, feedback = self.ai_semantic_grade(
|
| 98 |
+
answer.question_text,
|
| 99 |
+
answer.student_response,
|
| 100 |
+
correct_answer,
|
| 101 |
+
answer.max_score,
|
| 102 |
+
course=course
|
| 103 |
+
)
|
| 104 |
+
is_correct = score >= (answer.max_score * self.semantic_threshold)
|
| 105 |
+
else:
|
| 106 |
+
similarity = self.simple_similarity(
|
| 107 |
+
answer.student_response,
|
| 108 |
+
correct_answer
|
| 109 |
+
)
|
| 110 |
+
score = similarity * answer.max_score
|
| 111 |
+
is_correct = similarity >= self.semantic_threshold
|
| 112 |
+
feedback = f"Similarity score {similarity:.2f}"
|
| 113 |
+
|
| 114 |
+
return GradedAnswer(
|
| 115 |
+
question_no=answer.question_no,
|
| 116 |
+
question_type=answer.question_type,
|
| 117 |
+
question_text=answer.question_text,
|
| 118 |
+
student_response=answer.student_response,
|
| 119 |
+
correct_answer=correct_answer,
|
| 120 |
+
score=score,
|
| 121 |
+
max_score=answer.max_score,
|
| 122 |
+
feedback=feedback,
|
| 123 |
+
is_correct=is_correct
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
def simple_similarity(self, student: str, correct: str) -> float:
|
| 127 |
+
if not student or not correct:
|
| 128 |
+
return 0
|
| 129 |
+
student_words = set(student.lower().split())
|
| 130 |
+
correct_words = set(correct.lower().split())
|
| 131 |
+
intersection = student_words.intersection(correct_words)
|
| 132 |
+
union = student_words.union(correct_words)
|
| 133 |
+
return len(intersection) / len(union)
|
| 134 |
+
|
| 135 |
+
def retrieve_context(self, question: str, course:str):
|
| 136 |
+
"""
|
| 137 |
+
Retrieve relevant context from Qdrant for a given question filtered by course
|
| 138 |
+
Args: question: The question text to embed and search for // course: Optional course filter
|
| 139 |
+
Returns: String containing concatenated context from top 3 chunks
|
| 140 |
+
"""
|
| 141 |
+
try:
|
| 142 |
+
controller = IndexingController()
|
| 143 |
+
embedding = controller.embedder.embed_text(question)
|
| 144 |
+
|
| 145 |
+
# Build metadata filters course
|
| 146 |
+
filters = []
|
| 147 |
+
if course:
|
| 148 |
+
filters.append({
|
| 149 |
+
"field": "course",
|
| 150 |
+
"op": "eq",
|
| 151 |
+
"value": course,
|
| 152 |
+
"clause": "must"
|
| 153 |
+
})
|
| 154 |
+
|
| 155 |
+
# Query Qdrant with filters
|
| 156 |
+
results = controller.vector_store.query_qdrant(embedding=embedding,filters=filters,top_k=5)
|
| 157 |
+
|
| 158 |
+
context = "\n".join(r["content"] for r in results if r.get("content"))
|
| 159 |
+
|
| 160 |
+
logger.info(f"Retrieved {len(results)} chunks for question (filtered by course={course})")
|
| 161 |
+
return context
|
| 162 |
+
|
| 163 |
+
except Exception as e:
|
| 164 |
+
logger.error(f"Context retrieval failed: {e}")
|
| 165 |
+
return ""
|
| 166 |
+
|
| 167 |
+
def build_prompt(self, question, student_answer, correct_answer, context):
|
| 168 |
+
return f"""
|
| 169 |
+
You are an academic exam grader.
|
| 170 |
+
|
| 171 |
+
Question:
|
| 172 |
+
{question}
|
| 173 |
+
|
| 174 |
+
Correct Answer:
|
| 175 |
+
{correct_answer}
|
| 176 |
+
|
| 177 |
+
Reference Material:
|
| 178 |
+
{context}
|
| 179 |
+
|
| 180 |
+
Student Answer:
|
| 181 |
+
{student_answer}
|
| 182 |
+
|
| 183 |
+
Evaluate the student answer using semantic similarity.
|
| 184 |
+
You may slightly use your knowledge if correct answer not in Reference Material.
|
| 185 |
+
|
| 186 |
+
Return JSON only:
|
| 187 |
+
|
| 188 |
+
{{
|
| 189 |
+
"score": number between 0 and 1,
|
| 190 |
+
"feedback": short explanation
|
| 191 |
+
}}
|
| 192 |
+
"""
|
| 193 |
+
|
| 194 |
+
def parse_llm_output(self, text: str):
|
| 195 |
+
try:
|
| 196 |
+
if isinstance(text, dict):
|
| 197 |
+
if 'response' in text:
|
| 198 |
+
text = text['response']
|
| 199 |
+
else:
|
| 200 |
+
text = str(text)
|
| 201 |
+
elif hasattr(text, 'content'):
|
| 202 |
+
text = text.content
|
| 203 |
+
elif hasattr(text, 'text'):
|
| 204 |
+
text = text.text
|
| 205 |
+
text = str(text).strip()
|
| 206 |
+
if not text:
|
| 207 |
+
return 0, "Empty response from LLM"
|
| 208 |
+
text = re.sub(r'```json\s*|\s*```', '', text)
|
| 209 |
+
try:
|
| 210 |
+
data = json.loads(text)
|
| 211 |
+
except json.JSONDecodeError:
|
| 212 |
+
json_match = re.search(r'\{.*\}', text, re.DOTALL)
|
| 213 |
+
if json_match:
|
| 214 |
+
data = json.loads(json_match.group())
|
| 215 |
+
else:
|
| 216 |
+
raise
|
| 217 |
+
|
| 218 |
+
score = float(data.get("score", 0))
|
| 219 |
+
feedback = data.get("feedback", "")
|
| 220 |
+
score = max(0, min(score, 1))
|
| 221 |
+
return score, feedback
|
| 222 |
+
|
| 223 |
+
except Exception as e:
|
| 224 |
+
logger.error(f"Failed to parse LLM output: {e}, text type: {type(text)}")
|
| 225 |
+
return 0, "Failed to parse AI grading"
|
| 226 |
+
|
| 227 |
+
def ai_semantic_grade(self, question, student, correct, max_score, course):
|
| 228 |
+
"""
|
| 229 |
+
Grade an answer using AI with context from Qdrant.
|
| 230 |
+
Args: question: The question text // student: Student's answer // correct: Correct answer
|
| 231 |
+
max_score: Maximum score for this question // course: Optional course for filtering context
|
| 232 |
+
Returns: // Tuple of (score, feedback)
|
| 233 |
+
"""
|
| 234 |
+
try:
|
| 235 |
+
# Retrieve context filtered by username and course
|
| 236 |
+
context = self.retrieve_context(question, course)
|
| 237 |
+
|
| 238 |
+
prompt = self.build_prompt(question,student,correct,context)
|
| 239 |
+
|
| 240 |
+
response = self.llm.generate_text(prompt)
|
| 241 |
+
|
| 242 |
+
# Log response type for debugging
|
| 243 |
+
logger.info(f"Response type: {type(response)}")
|
| 244 |
+
|
| 245 |
+
score_ratio, feedback = self.parse_llm_output(response)
|
| 246 |
+
score = score_ratio * max_score
|
| 247 |
+
|
| 248 |
+
return score, feedback
|
| 249 |
+
|
| 250 |
+
except Exception as e:
|
| 251 |
+
logger.error(f"AI grading failed: {e}")
|
| 252 |
+
# Fallback to simple similarity
|
| 253 |
+
similarity = self.simple_similarity(student, correct)
|
| 254 |
+
return similarity * max_score, f"Fallback similarity grading: {similarity:.2f}"
|
| 255 |
+
|
| 256 |
+
@shared_task
|
| 257 |
+
def grade_exam_task(submission_dict: Dict[str, Any]):
|
| 258 |
+
submission = None
|
| 259 |
+
try:
|
| 260 |
+
submission = ExamSubmission(**submission_dict)
|
| 261 |
+
service = ExamGradingService()
|
| 262 |
+
result = service.grade_submission(submission)
|
| 263 |
+
result_dict = result.model_dump()
|
| 264 |
+
|
| 265 |
+
# Send webhook with grade only
|
| 266 |
+
try:
|
| 267 |
+
webhook_url = get_settings().GRADE_WEBHOOK_URL
|
| 268 |
+
print(f" Webhook URL: {webhook_url}")
|
| 269 |
+
|
| 270 |
+
if webhook_url:
|
| 271 |
+
# Create grade-only payload
|
| 272 |
+
grade_only_payload = {
|
| 273 |
+
"status": "completed",
|
| 274 |
+
"exam_id": submission.exam_id,
|
| 275 |
+
"student_id": submission.student_id,
|
| 276 |
+
"course_id":submission.course_id,
|
| 277 |
+
"grade": {
|
| 278 |
+
"total_score": result_dict['total_score'],
|
| 279 |
+
"max_total_score": result_dict['max_total_score'],
|
| 280 |
+
"percentage": result_dict['percentage'],
|
| 281 |
+
"grade": result_dict['grade'],
|
| 282 |
+
"graded_time": result_dict['graded_time']
|
| 283 |
+
},
|
| 284 |
+
"result" : result_dict,
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
response = httpx.post(
|
| 288 |
+
webhook_url,
|
| 289 |
+
json=grade_only_payload,
|
| 290 |
+
timeout=30.0
|
| 291 |
+
)
|
| 292 |
+
print(f" Response status: {response.status_code}")
|
| 293 |
+
|
| 294 |
+
if response.status_code == 200:
|
| 295 |
+
print(" Grade-only webhook sent successfully!")
|
| 296 |
+
else:
|
| 297 |
+
print(f" Webhook returned status: {response.status_code}")
|
| 298 |
+
print(f" Response: {response.text[:200]}")
|
| 299 |
+
else:
|
| 300 |
+
print("WEBHOOK_URL is empty or not set!")
|
| 301 |
+
|
| 302 |
+
except Exception as e:
|
| 303 |
+
print(f" Webhook error: {type(e).__name__}: {e}")
|
| 304 |
+
import traceback
|
| 305 |
+
traceback.print_exc()
|
| 306 |
+
|
| 307 |
+
print(" Task completed successfully")
|
| 308 |
+
return result_dict
|
| 309 |
+
|
| 310 |
+
except Exception as e:
|
| 311 |
+
print(f" ERROR in task: {type(e).__name__}: {e}")
|
| 312 |
+
import traceback
|
| 313 |
+
traceback.print_exc()
|
| 314 |
+
raise
|
generation/ExamRagGenerator.py
ADDED
|
@@ -0,0 +1,460 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict
|
| 2 |
+
import logging
|
| 3 |
+
import json
|
| 4 |
+
import re
|
| 5 |
+
import math
|
| 6 |
+
from json_repair import repair_json
|
| 7 |
+
from pydantic import parse_obj_as
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
from config import get_settings
|
| 10 |
+
from routes.schemas.Exam_Models import *
|
| 11 |
+
from stores.llm.LLMProviderFactory import LLMProviderFactory
|
| 12 |
+
from generation.AssistantRagGenerator import ProviderLLMWrapper
|
| 13 |
+
from generation.prompts import ExamPromptBuilder
|
| 14 |
+
from indexing.indexingController import IndexingController
|
| 15 |
+
|
| 16 |
+
class ExamService:
|
| 17 |
+
MAX_CHUNK_CHARS = 2000
|
| 18 |
+
MAX_TOTAL_CONTEXT = 8000
|
| 19 |
+
MAX_SCORE = 40
|
| 20 |
+
PASS_THRESHOLD = int(MAX_SCORE * 0.8)
|
| 21 |
+
MAX_GENERATION_ATTEMPTS = 3
|
| 22 |
+
|
| 23 |
+
def __init__(self):
|
| 24 |
+
self.logger = logging.getLogger(__name__)
|
| 25 |
+
self._models_initialized = False
|
| 26 |
+
self.settings=get_settings()
|
| 27 |
+
self._init_models()
|
| 28 |
+
self.prompts=ExamPromptBuilder()
|
| 29 |
+
self.controller = IndexingController()
|
| 30 |
+
self.store = self.controller.vector_store
|
| 31 |
+
self.BATCH_SIZE=10
|
| 32 |
+
|
| 33 |
+
def _init_models(self):
|
| 34 |
+
if self._models_initialized:
|
| 35 |
+
return
|
| 36 |
+
factory = LLMProviderFactory(self.settings)
|
| 37 |
+
self.generator = factory.create(self.settings.GENERATION_BACKEND)
|
| 38 |
+
self.generator.set_generation_model(self.settings.GENERATION_MODEL_ID)
|
| 39 |
+
self.embedding_provider = factory.create(self.settings.EMBEDDING_BACKEND)
|
| 40 |
+
self.embedding_provider.set_embedding_model(
|
| 41 |
+
self.settings.EMBEDDING_MODEL_ID,
|
| 42 |
+
self.settings.EMBEDDING_MODEL_SIZE
|
| 43 |
+
)
|
| 44 |
+
self.llm = ProviderLLMWrapper(provider=self.generator)
|
| 45 |
+
self._models_initialized = True
|
| 46 |
+
|
| 47 |
+
def _extract_json(self, text: str) -> dict:
|
| 48 |
+
"""
|
| 49 |
+
Extract the first valid JSON object from LLM output. Attempts to repair malformed JSON using `repair_json`.
|
| 50 |
+
"""
|
| 51 |
+
match = re.search(r"\{.*\}", text, re.DOTALL)
|
| 52 |
+
if not match:
|
| 53 |
+
self.logger.error("No JSON found in LLM response:\n%s", text)
|
| 54 |
+
raise ValueError("LLM returned no JSON")
|
| 55 |
+
json_str = match.group(0)
|
| 56 |
+
# Try to load directly
|
| 57 |
+
try:
|
| 58 |
+
return json.loads(json_str)
|
| 59 |
+
except json.JSONDecodeError:
|
| 60 |
+
self.logger.warning("Invalid JSON extracted, attempting repair...")
|
| 61 |
+
try:
|
| 62 |
+
repaired_str = repair_json(json_str)
|
| 63 |
+
return json.loads(repaired_str)
|
| 64 |
+
except Exception as e:
|
| 65 |
+
self.logger.error("Failed to repair JSON:\n%s\nError: %s", json_str, e)
|
| 66 |
+
raise
|
| 67 |
+
|
| 68 |
+
def normalize_exam_dict(self, data: dict):
|
| 69 |
+
# Normalize difficulty enum
|
| 70 |
+
if "difficulty" in data:
|
| 71 |
+
diff = data["difficulty"]
|
| 72 |
+
if isinstance(diff, str):
|
| 73 |
+
if "." in diff:
|
| 74 |
+
diff = diff.split(".")[-1]
|
| 75 |
+
data["difficulty"] = diff.lower()
|
| 76 |
+
# Normalize questions
|
| 77 |
+
questions = data.get("questions")
|
| 78 |
+
if not isinstance(questions, list):
|
| 79 |
+
return data
|
| 80 |
+
normalized_questions = []
|
| 81 |
+
for q in questions:
|
| 82 |
+
if not isinstance(q, dict):
|
| 83 |
+
continue
|
| 84 |
+
q.pop("id", None)
|
| 85 |
+
q.pop("question_id", None)
|
| 86 |
+
q.pop("points", None)
|
| 87 |
+
|
| 88 |
+
# normalize type
|
| 89 |
+
q_type = q.get("type")
|
| 90 |
+
if isinstance(q_type, str):
|
| 91 |
+
q_type = q_type.lower().strip()
|
| 92 |
+
if q_type == "truefalse":
|
| 93 |
+
q_type = "true_false"
|
| 94 |
+
q["type"] = q_type
|
| 95 |
+
|
| 96 |
+
# normalize question text
|
| 97 |
+
if "question" in q:
|
| 98 |
+
q["question"] = str(q["question"]).strip()
|
| 99 |
+
|
| 100 |
+
# MCQ normalization
|
| 101 |
+
if q_type == "mcq":
|
| 102 |
+
options = q.get("options")
|
| 103 |
+
# dict -> list
|
| 104 |
+
if isinstance(options, dict):
|
| 105 |
+
options = list(options.values())
|
| 106 |
+
# string -> split into options
|
| 107 |
+
elif isinstance(options, str):
|
| 108 |
+
parts = re.split(r"[A-D]\)|\n|\r", options)
|
| 109 |
+
options = [
|
| 110 |
+
p.strip(" .-")
|
| 111 |
+
for p in parts
|
| 112 |
+
if p.strip()
|
| 113 |
+
]
|
| 114 |
+
# ensure list[str]
|
| 115 |
+
if isinstance(options, list):
|
| 116 |
+
options = [str(o).strip() for o in options]
|
| 117 |
+
else:
|
| 118 |
+
options = []
|
| 119 |
+
q["options"] = options
|
| 120 |
+
|
| 121 |
+
# normalize correct answer
|
| 122 |
+
correct = q.get("correct_answer")
|
| 123 |
+
if correct is not None:
|
| 124 |
+
correct = str(correct).strip()
|
| 125 |
+
q["correct_answer"] = correct
|
| 126 |
+
# ensure correct answer exists in options
|
| 127 |
+
if correct not in q["options"]:
|
| 128 |
+
q["options"].append(correct)
|
| 129 |
+
# ensure explanation exists
|
| 130 |
+
q.setdefault("explanation", "")
|
| 131 |
+
|
| 132 |
+
# True/False normalization
|
| 133 |
+
elif q_type == "true_false":
|
| 134 |
+
ans = q.get("correct_answer")
|
| 135 |
+
if isinstance(ans, str):
|
| 136 |
+
ans = ans.lower()
|
| 137 |
+
if ans in ["true", "t", "1", "yes"]:
|
| 138 |
+
ans = True
|
| 139 |
+
elif ans in ["false", "f", "0", "no"]:
|
| 140 |
+
ans = False
|
| 141 |
+
q["correct_answer"] = ans
|
| 142 |
+
q.setdefault("explanation", "")
|
| 143 |
+
|
| 144 |
+
# Short Answer normalization
|
| 145 |
+
elif q_type == "short_answer":
|
| 146 |
+
if "answer" in q:
|
| 147 |
+
q["answer"] = str(q["answer"]).strip()
|
| 148 |
+
q.setdefault("explanation", "")
|
| 149 |
+
|
| 150 |
+
# Essay normalization
|
| 151 |
+
elif q_type == "essay":
|
| 152 |
+
if "expected_keywords" in q:
|
| 153 |
+
keywords = q.pop("expected_keywords")
|
| 154 |
+
if isinstance(keywords, list):
|
| 155 |
+
q["answer_guidelines"] = ", ".join(keywords)
|
| 156 |
+
else:
|
| 157 |
+
q["answer_guidelines"] = str(keywords)
|
| 158 |
+
q.setdefault("answer_guidelines", "")
|
| 159 |
+
|
| 160 |
+
# Code question normalization
|
| 161 |
+
elif q_type == "code":
|
| 162 |
+
if "solution" in q:
|
| 163 |
+
q["solution"] = str(q["solution"])
|
| 164 |
+
q.setdefault("starter_code", None)
|
| 165 |
+
q.setdefault("explanation", "")
|
| 166 |
+
normalized_questions.append(q)
|
| 167 |
+
data["questions"] = normalized_questions
|
| 168 |
+
|
| 169 |
+
return data
|
| 170 |
+
|
| 171 |
+
def generate_exam(self, request: ExamGenerationRequest, context: str, llm, batch_size: int) -> List[QuestionUnion]:
|
| 172 |
+
"""
|
| 173 |
+
Generate a batch of questions from the LLM, ensuring valid QuestionUnion objects.Repairs incomplete MCQs automatically.
|
| 174 |
+
"""
|
| 175 |
+
# Prepare the prompt for the batch
|
| 176 |
+
batch_request = request.model_copy()
|
| 177 |
+
batch_request.total_questions = batch_size
|
| 178 |
+
|
| 179 |
+
prompt = self.prompts.build_exam_generation_prompt(batch_request, context)
|
| 180 |
+
raw_text = llm._call(prompt)
|
| 181 |
+
|
| 182 |
+
if not raw_text:
|
| 183 |
+
raise RuntimeError("LLM generation failed")
|
| 184 |
+
|
| 185 |
+
cleaned = re.sub(r"```[a-zA-Z]*|```", "", raw_text).strip()
|
| 186 |
+
|
| 187 |
+
try:
|
| 188 |
+
exam_dict = self._extract_json(cleaned)
|
| 189 |
+
exam_dict = self.normalize_exam_dict(exam_dict)
|
| 190 |
+
|
| 191 |
+
questions = exam_dict.get("questions") or []
|
| 192 |
+
questions = questions[:batch_size]
|
| 193 |
+
|
| 194 |
+
# Repair incomplete MCQs or missing fields
|
| 195 |
+
repaired_questions = []
|
| 196 |
+
for q in questions:
|
| 197 |
+
if not isinstance(q, dict):
|
| 198 |
+
continue # skip invalid entries
|
| 199 |
+
q_type = q.get("type")
|
| 200 |
+
if q_type == "mcq":
|
| 201 |
+
if not q.get("options"):
|
| 202 |
+
self.logger.warning(f"Skipping MCQ with no options: {q}")
|
| 203 |
+
continue
|
| 204 |
+
if not q.get("correct_answer"):
|
| 205 |
+
q["correct_answer"] = q["options"][0] # safe placeholder
|
| 206 |
+
repaired_questions.append(q)
|
| 207 |
+
|
| 208 |
+
# Convert to Pydantic QuestionUnion objects
|
| 209 |
+
questions = parse_obj_as(List[QuestionUnion], repaired_questions)
|
| 210 |
+
|
| 211 |
+
self.logger.info(
|
| 212 |
+
"Batch requested=%d | received=%d | kept=%d",
|
| 213 |
+
batch_size,
|
| 214 |
+
len(exam_dict.get("questions", [])),
|
| 215 |
+
len(questions),
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
except json.JSONDecodeError:
|
| 219 |
+
self.logger.error("Invalid JSON from LLM:\n%s", raw_text)
|
| 220 |
+
raise
|
| 221 |
+
|
| 222 |
+
return questions
|
| 223 |
+
|
| 224 |
+
def evaluate_exam(self, request: ExamGenerationRequest, exam: ExamResponse, llm):
|
| 225 |
+
prompt = self.prompts.build_exam_evaluation_prompt(request, exam)
|
| 226 |
+
raw_text = llm._call(prompt)
|
| 227 |
+
|
| 228 |
+
if not raw_text:
|
| 229 |
+
raise RuntimeError("Evaluation generation failed")
|
| 230 |
+
|
| 231 |
+
cleaned = re.sub(r"```[a-zA-Z]*|```", "", raw_text).strip()
|
| 232 |
+
|
| 233 |
+
try:
|
| 234 |
+
evaluation_dict = self._extract_json(cleaned)
|
| 235 |
+
except json.JSONDecodeError:
|
| 236 |
+
self.logger.error("Invalid evaluation JSON:\n%s", raw_text)
|
| 237 |
+
raise
|
| 238 |
+
|
| 239 |
+
return EvaluationResult.model_validate(evaluation_dict)
|
| 240 |
+
|
| 241 |
+
def split_chunks_by_topic_batches(self, exam_chunks, num_batches):
|
| 242 |
+
|
| 243 |
+
self.logger.info(f"Topics retrieved: {list(exam_chunks.keys())}")
|
| 244 |
+
self.logger.info(f"Number of batches: {num_batches}")
|
| 245 |
+
|
| 246 |
+
batches = [[] for _ in range(num_batches)]
|
| 247 |
+
|
| 248 |
+
for topic, chunks in exam_chunks.items():
|
| 249 |
+
total_chunks = len(chunks)
|
| 250 |
+
self.logger.info(f"Topic '{topic}' -> {total_chunks} chunks distributed across batches")
|
| 251 |
+
|
| 252 |
+
for idx, chunk in enumerate(chunks):
|
| 253 |
+
batch_index = idx % num_batches
|
| 254 |
+
batches[batch_index].append(chunk)
|
| 255 |
+
|
| 256 |
+
# Log batch composition
|
| 257 |
+
for i, batch in enumerate(batches):
|
| 258 |
+
topic_counter = defaultdict(int)
|
| 259 |
+
for chunk in batch:
|
| 260 |
+
topic = chunk.get("metadata", {}).get("topic", "unknown")
|
| 261 |
+
topic_counter[topic] += 1
|
| 262 |
+
self.logger.info(f"Batch {i+1} contains {len(batch)} chunks -> {dict(topic_counter)}")
|
| 263 |
+
|
| 264 |
+
return batches
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def exam_task(self, request_dict: dict) -> ExamResponse:
|
| 268 |
+
"""
|
| 269 |
+
Generate a full exam using batching, safety break, and validated QuestionUnion questions.Each batch receives a portion of the retrieved chunks.
|
| 270 |
+
"""
|
| 271 |
+
request = ExamGenerationRequest.model_validate(request_dict)
|
| 272 |
+
# Prepare context from knowledge store
|
| 273 |
+
topics_with_embeddings = self.prepare_topics_with_embeddings(request.topics)
|
| 274 |
+
exam_chunks = self.store.retrieve_for_exam(topics_with_embeddings,request.username,request.course,request.references)
|
| 275 |
+
|
| 276 |
+
# Determine number of batches
|
| 277 |
+
num_batches = math.ceil(request.total_questions / self.BATCH_SIZE)
|
| 278 |
+
self.logger.info(f"Raw exam_chunks structure: {type(exam_chunks)}")
|
| 279 |
+
|
| 280 |
+
for k, v in exam_chunks.items():
|
| 281 |
+
self.logger.info(f"Topic={k} | type={type(v)} | len={len(v) if hasattr(v,'__len__') else 'NA'}")
|
| 282 |
+
|
| 283 |
+
chunk_batches = self.split_chunks_by_topic_batches(exam_chunks,num_batches)
|
| 284 |
+
|
| 285 |
+
feedback_context = ""
|
| 286 |
+
|
| 287 |
+
best_exam = None
|
| 288 |
+
best_score = 0
|
| 289 |
+
|
| 290 |
+
for attempt in range(self.MAX_GENERATION_ATTEMPTS):
|
| 291 |
+
self.logger.info(f"Generating exam attempt {attempt+1}")
|
| 292 |
+
|
| 293 |
+
remaining_distribution: Dict[QuestionType, int] = dict(request.question_types_distribution)
|
| 294 |
+
all_questions: List[QuestionUnion] = []
|
| 295 |
+
batch_index = 0
|
| 296 |
+
|
| 297 |
+
# Batch generation loop
|
| 298 |
+
while len(all_questions) < request.total_questions:
|
| 299 |
+
remaining = request.total_questions - len(all_questions)
|
| 300 |
+
batch_size = min(self.BATCH_SIZE, remaining)
|
| 301 |
+
# Determine batch distribution
|
| 302 |
+
batch_distribution: Dict[QuestionType, int] = {}
|
| 303 |
+
slots_left = batch_size
|
| 304 |
+
|
| 305 |
+
for qtype, count in remaining_distribution.items():
|
| 306 |
+
if slots_left <= 0:
|
| 307 |
+
break
|
| 308 |
+
|
| 309 |
+
take = min(count, slots_left)
|
| 310 |
+
|
| 311 |
+
if take > 0:
|
| 312 |
+
batch_distribution[qtype] = take
|
| 313 |
+
slots_left -= take
|
| 314 |
+
|
| 315 |
+
if not batch_distribution:
|
| 316 |
+
break
|
| 317 |
+
|
| 318 |
+
batch_request = request.model_copy()
|
| 319 |
+
batch_request.total_questions = sum(batch_distribution.values())
|
| 320 |
+
batch_request.question_types_distribution = batch_distribution
|
| 321 |
+
|
| 322 |
+
# Select chunk subset for this batch
|
| 323 |
+
|
| 324 |
+
chunk_subset = chunk_batches[batch_index % len(chunk_batches)]
|
| 325 |
+
self.logger.info(f"\n===== BATCH {batch_index+1} CHUNKS =====")
|
| 326 |
+
|
| 327 |
+
for i, chunk in enumerate(chunk_subset):
|
| 328 |
+
|
| 329 |
+
meta = chunk.get("metadata", {})
|
| 330 |
+
topic = meta.get("topic", "unknown")
|
| 331 |
+
page = meta.get("page", "NA")
|
| 332 |
+
|
| 333 |
+
# Try common text keys
|
| 334 |
+
text = chunk.get("text") or chunk.get("content") or chunk.get("page_content") or ""
|
| 335 |
+
|
| 336 |
+
preview = text[:200].replace("\n", " ")
|
| 337 |
+
|
| 338 |
+
self.logger.info(
|
| 339 |
+
f"Chunk {i+1} | Topic={topic} | Page={page} | Preview={preview}"
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
self.logger.info("=====================================\n")
|
| 343 |
+
|
| 344 |
+
batch_index += 1
|
| 345 |
+
|
| 346 |
+
batch_context = self.build_exam_context(chunk_subset)
|
| 347 |
+
|
| 348 |
+
if feedback_context:
|
| 349 |
+
batch_context += f"\n\nEvaluator Feedback:\n{feedback_context}"
|
| 350 |
+
|
| 351 |
+
# Generate questions
|
| 352 |
+
|
| 353 |
+
batch_questions = self.generate_exam(batch_request,batch_context,self.llm,batch_request.total_questions)
|
| 354 |
+
|
| 355 |
+
# Filter generated questions
|
| 356 |
+
for q in batch_questions:
|
| 357 |
+
if remaining_distribution.get(q.type, 0) > 0:
|
| 358 |
+
all_questions.append(q)
|
| 359 |
+
remaining_distribution[q.type] -= 1
|
| 360 |
+
if len(all_questions) >= request.total_questions:
|
| 361 |
+
break
|
| 362 |
+
|
| 363 |
+
# Build final exam
|
| 364 |
+
|
| 365 |
+
exam_dict = {
|
| 366 |
+
"exam_id": request.exam_id,
|
| 367 |
+
"difficulty": request.difficulty,
|
| 368 |
+
"total_questions": request.total_questions,
|
| 369 |
+
"expected_distribution": request.question_types_distribution,
|
| 370 |
+
"questions": all_questions[:request.total_questions],
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
try:
|
| 374 |
+
exam = ExamResponse.model_validate(exam_dict)
|
| 375 |
+
except Exception as e:
|
| 376 |
+
self.logger.error(f"Exam validation failed: {e}")
|
| 377 |
+
raise
|
| 378 |
+
|
| 379 |
+
evaluation = self.evaluate_exam(request, exam, self.llm)
|
| 380 |
+
self.logger.info(f"Evaluation score: {evaluation.overall_score}")
|
| 381 |
+
|
| 382 |
+
if evaluation.overall_score > best_score:
|
| 383 |
+
best_score = evaluation.overall_score
|
| 384 |
+
best_exam = exam
|
| 385 |
+
|
| 386 |
+
if evaluation.overall_score >= self.PASS_THRESHOLD:
|
| 387 |
+
break
|
| 388 |
+
|
| 389 |
+
feedback_context = evaluation.feedback
|
| 390 |
+
|
| 391 |
+
if best_exam is None:
|
| 392 |
+
raise RuntimeError("Exam generation failed after retries")
|
| 393 |
+
|
| 394 |
+
return best_exam
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def build_exam_context(self, exam_chunks) -> str:
|
| 399 |
+
"""
|
| 400 |
+
Accepts either:
|
| 401 |
+
1) {topic: [chunks]}
|
| 402 |
+
2) [chunks]
|
| 403 |
+
"""
|
| 404 |
+
|
| 405 |
+
# Normalize structure
|
| 406 |
+
if isinstance(exam_chunks, list):
|
| 407 |
+
topic_chunks = defaultdict(list)
|
| 408 |
+
|
| 409 |
+
for c in exam_chunks:
|
| 410 |
+
topic = c.get("metadata", {}).get("topic", "Unknown")
|
| 411 |
+
topic_chunks[topic].append(c)
|
| 412 |
+
|
| 413 |
+
exam_chunks = topic_chunks
|
| 414 |
+
|
| 415 |
+
context_parts = []
|
| 416 |
+
total_length = 0
|
| 417 |
+
|
| 418 |
+
for topic, chunks in exam_chunks.items():
|
| 419 |
+
|
| 420 |
+
topic_header = f"\n### Topic: {topic}\n"
|
| 421 |
+
|
| 422 |
+
if total_length + len(topic_header) > self.MAX_TOTAL_CONTEXT:
|
| 423 |
+
break
|
| 424 |
+
|
| 425 |
+
context_parts.append(topic_header)
|
| 426 |
+
total_length += len(topic_header)
|
| 427 |
+
|
| 428 |
+
for c in chunks:
|
| 429 |
+
|
| 430 |
+
text = c.get("payload", {}).get("text", "")
|
| 431 |
+
source = c.get("metadata", {}).get("source", "")
|
| 432 |
+
bookmark = c.get("metadata", {}).get("bookmark_path", "")
|
| 433 |
+
|
| 434 |
+
if not isinstance(text, str):
|
| 435 |
+
continue
|
| 436 |
+
|
| 437 |
+
if len(text) > self.MAX_CHUNK_CHARS:
|
| 438 |
+
text = text[:self.MAX_CHUNK_CHARS]
|
| 439 |
+
|
| 440 |
+
formatted_chunk = (f"[Source: {source} | Bookmark: {bookmark}]\n{text}\n")
|
| 441 |
+
|
| 442 |
+
if total_length + len(formatted_chunk) > self.MAX_TOTAL_CONTEXT:
|
| 443 |
+
break
|
| 444 |
+
|
| 445 |
+
context_parts.append(formatted_chunk)
|
| 446 |
+
total_length += len(formatted_chunk)
|
| 447 |
+
|
| 448 |
+
return "\n".join(context_parts)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def prepare_topics_with_embeddings(self, topics: List[str]):
|
| 452 |
+
results = []
|
| 453 |
+
for topic in topics:
|
| 454 |
+
try:
|
| 455 |
+
embedding = self.embedding_provider.embed_text(topic)
|
| 456 |
+
results.append((topic, embedding))
|
| 457 |
+
except Exception as e:
|
| 458 |
+
self.logger.warning(f"Embedding failed for topic '{topic}': {e}")
|
| 459 |
+
self.logger.info(f"Prepared {len(results)} topic embeddings")
|
| 460 |
+
return results
|
generation/__init__.py
ADDED
|
File without changes
|
generation/answer_models.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import List, Optional, Dict, Any, Union
|
| 3 |
+
from enum import Enum
|
| 4 |
+
|
| 5 |
+
class QuestionType(str, Enum):
|
| 6 |
+
MULTIPLE_CHOICE = "multiple_choice"
|
| 7 |
+
TRUE_FALSE = "true_false"
|
| 8 |
+
SHORT_ANSWER = "short_answer"
|
| 9 |
+
ESSAY = "essay"
|
| 10 |
+
CODE = "code"
|
| 11 |
+
|
| 12 |
+
class StudentAnswer(BaseModel):
|
| 13 |
+
question_no: int
|
| 14 |
+
question_type: QuestionType
|
| 15 |
+
question_text: str
|
| 16 |
+
student_response: str
|
| 17 |
+
max_score: float = 1.0
|
| 18 |
+
metadata: Optional[Dict[str, Any]] = {}
|
| 19 |
+
|
| 20 |
+
class GradedAnswer(BaseModel):
|
| 21 |
+
question_no: int
|
| 22 |
+
question_type: QuestionType
|
| 23 |
+
question_text: str
|
| 24 |
+
student_response: str
|
| 25 |
+
correct_answer: Optional[Any]
|
| 26 |
+
score: float
|
| 27 |
+
max_score: float
|
| 28 |
+
feedback: str
|
| 29 |
+
is_correct: bool
|
| 30 |
+
|
| 31 |
+
class ExamSubmission(BaseModel):
|
| 32 |
+
exam_id: str
|
| 33 |
+
course_id: str
|
| 34 |
+
student_id: str
|
| 35 |
+
student_name: Optional[str]
|
| 36 |
+
answers: List[StudentAnswer]
|
| 37 |
+
submission_time: str
|
| 38 |
+
metadata: Optional[Dict[str, Any]] = {}
|
| 39 |
+
|
| 40 |
+
class ExamResult(BaseModel):
|
| 41 |
+
exam_id: str
|
| 42 |
+
student_id: str
|
| 43 |
+
student_name: Optional[str]
|
| 44 |
+
graded_answers: List[GradedAnswer]
|
| 45 |
+
total_score: float
|
| 46 |
+
max_total_score: float
|
| 47 |
+
percentage: float
|
| 48 |
+
grade: Optional[str]
|
| 49 |
+
feedback_summary: Optional[str]
|
| 50 |
+
submission_time: str
|
| 51 |
+
graded_time: str
|
generation/parsing_utils.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import logging
|
| 3 |
+
from typing import Any, Dict, Optional
|
| 4 |
+
|
| 5 |
+
logger = logging.getLogger("ExamGraph")
|
| 6 |
+
|
| 7 |
+
def safe_parse(parser_obj, text: str, question_no: int) -> Optional[Dict[str, Any]]:
|
| 8 |
+
if not text or text.strip() in ("null", "None", ""):
|
| 9 |
+
logger.warning(f"[Parse] q{question_no}: empty/null response")
|
| 10 |
+
return None
|
| 11 |
+
|
| 12 |
+
last_error = None
|
| 13 |
+
|
| 14 |
+
# Try direct parse
|
| 15 |
+
try:
|
| 16 |
+
result = parser_obj.parse(text)
|
| 17 |
+
return result.model_dump() if hasattr(result, "model_dump") else result
|
| 18 |
+
except Exception as e:
|
| 19 |
+
last_error = e
|
| 20 |
+
logger.debug(f"[Parse] q{question_no}: direct parse failed, trying extraction")
|
| 21 |
+
|
| 22 |
+
# Try to extract JSON from text (LLM may have wrapped it in prose)
|
| 23 |
+
try:
|
| 24 |
+
# look for {...} pattern
|
| 25 |
+
start = text.rfind("{")
|
| 26 |
+
end = text.rfind("}") + 1
|
| 27 |
+
if start >= 0 and end > start:
|
| 28 |
+
json_str = text[start:end]
|
| 29 |
+
json_obj = json.loads(json_str)
|
| 30 |
+
result = parser_obj.parse(json.dumps(json_obj))
|
| 31 |
+
return result.model_dump() if hasattr(result, "model_dump") else result
|
| 32 |
+
except Exception as e:
|
| 33 |
+
last_error = e
|
| 34 |
+
logger.debug(f"[Parse] q{question_no}: json extraction failed")
|
| 35 |
+
|
| 36 |
+
# Last resort: if it looks like partial JSON, mark for regen
|
| 37 |
+
error_msg = str(last_error) if last_error else "unknown"
|
| 38 |
+
logger.error(f"[Parse] q{question_no}: failed all attempts: {error_msg}")
|
| 39 |
+
return None
|
| 40 |
+
|
| 41 |
+
def categorize_error(error_str: str) -> str:
|
| 42 |
+
err = error_str.lower()
|
| 43 |
+
if "timeout" in err:
|
| 44 |
+
return "timeout"
|
| 45 |
+
elif "json" in err or "invalid" in err:
|
| 46 |
+
return "invalid_json"
|
| 47 |
+
elif "field required" in err or "missing" in err:
|
| 48 |
+
return "missing_field"
|
| 49 |
+
elif "none" in err or "null" in err:
|
| 50 |
+
return "null"
|
| 51 |
+
return "unknown"
|
generation/prompts.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from generation.ExamRagGenerator import ExamGenerationRequest, ExamResponse
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
class ExamPromptBuilder:
|
| 5 |
+
MAX_SCORE = 40
|
| 6 |
+
|
| 7 |
+
def build_exam_generation_prompt(self,request: ExamGenerationRequest,context: str) -> str:
|
| 8 |
+
distribution = {
|
| 9 |
+
q_type.value: count
|
| 10 |
+
for q_type, count in request.question_types_distribution.items()
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
return f"""
|
| 14 |
+
You are an automated exam generation system.
|
| 15 |
+
|
| 16 |
+
Your job is to produce a structured exam strictly following the schema below.
|
| 17 |
+
|
| 18 |
+
----------------------------------------------------
|
| 19 |
+
CRITICAL OUTPUT RULES
|
| 20 |
+
----------------------------------------------------
|
| 21 |
+
|
| 22 |
+
You MUST return ONLY a valid JSON object.
|
| 23 |
+
|
| 24 |
+
Do NOT include:
|
| 25 |
+
|
| 26 |
+
- explanations
|
| 27 |
+
- markdown
|
| 28 |
+
- comments
|
| 29 |
+
- code blocks
|
| 30 |
+
- text before or after the JSON
|
| 31 |
+
|
| 32 |
+
The response MUST start with {{ and end with }}.
|
| 33 |
+
|
| 34 |
+
If the output is not valid JSON the result will be rejected.
|
| 35 |
+
|
| 36 |
+
----------------------------------------------------
|
| 37 |
+
ENUM VALUES (STRICT)
|
| 38 |
+
----------------------------------------------------
|
| 39 |
+
|
| 40 |
+
difficulty must be exactly one of:
|
| 41 |
+
|
| 42 |
+
easy
|
| 43 |
+
medium
|
| 44 |
+
hard
|
| 45 |
+
|
| 46 |
+
question type must be exactly one of:
|
| 47 |
+
|
| 48 |
+
mcq
|
| 49 |
+
true_false
|
| 50 |
+
short_answer
|
| 51 |
+
essay
|
| 52 |
+
code
|
| 53 |
+
|
| 54 |
+
----------------------------------------------------
|
| 55 |
+
EXAM REQUIREMENTS
|
| 56 |
+
----------------------------------------------------
|
| 57 |
+
|
| 58 |
+
course: {request.course}
|
| 59 |
+
|
| 60 |
+
difficulty: {request.difficulty.value}
|
| 61 |
+
|
| 62 |
+
total_questions: {request.total_questions}
|
| 63 |
+
|
| 64 |
+
question_types_distribution:
|
| 65 |
+
{json.dumps(distribution)}
|
| 66 |
+
|
| 67 |
+
You MUST generate exactly:
|
| 68 |
+
|
| 69 |
+
{json.dumps(distribution)}
|
| 70 |
+
|
| 71 |
+
Example:
|
| 72 |
+
|
| 73 |
+
{{
|
| 74 |
+
"mcq": 3,
|
| 75 |
+
"essay": 2
|
| 76 |
+
}}
|
| 77 |
+
|
| 78 |
+
means exactly:
|
| 79 |
+
3 mcq questions
|
| 80 |
+
2 essay questions
|
| 81 |
+
|
| 82 |
+
----------------------------------------------------
|
| 83 |
+
CONTEXT
|
| 84 |
+
----------------------------------------------------
|
| 85 |
+
|
| 86 |
+
Use ONLY the information from this context when creating questions.
|
| 87 |
+
|
| 88 |
+
{context}
|
| 89 |
+
|
| 90 |
+
----------------------------------------------------
|
| 91 |
+
QUESTION RULES
|
| 92 |
+
----------------------------------------------------
|
| 93 |
+
|
| 94 |
+
MCQ QUESTIONS
|
| 95 |
+
|
| 96 |
+
- must contain exactly 4 options
|
| 97 |
+
- options must be plain text
|
| 98 |
+
- correct_answer must match one option EXACTLY
|
| 99 |
+
- do NOT use letters like A/B/C/D
|
| 100 |
+
- do NOT include numbering inside options
|
| 101 |
+
|
| 102 |
+
Example:
|
| 103 |
+
|
| 104 |
+
{{
|
| 105 |
+
"type": "mcq",
|
| 106 |
+
"question": "What is 2 + 2?",
|
| 107 |
+
"options": ["1","2","3","4"],
|
| 108 |
+
"correct_answer": "4",
|
| 109 |
+
"explanation": "2 + 2 equals 4"
|
| 110 |
+
}}
|
| 111 |
+
|
| 112 |
+
----------------------------------------------------
|
| 113 |
+
|
| 114 |
+
TRUE/FALSE QUESTIONS
|
| 115 |
+
|
| 116 |
+
correct_answer must be boolean.
|
| 117 |
+
|
| 118 |
+
Example:
|
| 119 |
+
|
| 120 |
+
{{
|
| 121 |
+
"type": "true_false",
|
| 122 |
+
"question": "The Earth revolves around the Sun.",
|
| 123 |
+
"correct_answer": true,
|
| 124 |
+
"explanation": "Astronomy confirms this."
|
| 125 |
+
}}
|
| 126 |
+
|
| 127 |
+
----------------------------------------------------
|
| 128 |
+
|
| 129 |
+
SHORT ANSWER QUESTIONS
|
| 130 |
+
|
| 131 |
+
Example:
|
| 132 |
+
|
| 133 |
+
{{
|
| 134 |
+
"type": "short_answer",
|
| 135 |
+
"question": "Define photosynthesis.",
|
| 136 |
+
"answer": "Process where plants convert light into chemical energy",
|
| 137 |
+
"explanation": "Occurs in chloroplasts using sunlight"
|
| 138 |
+
}}
|
| 139 |
+
|
| 140 |
+
----------------------------------------------------
|
| 141 |
+
|
| 142 |
+
ESSAY QUESTIONS
|
| 143 |
+
|
| 144 |
+
Example:
|
| 145 |
+
|
| 146 |
+
{{
|
| 147 |
+
"type": "essay",
|
| 148 |
+
"question": "Explain Newton's First Law.",
|
| 149 |
+
"answer": "Newton's First Law states that an object will remain at rest or continue moving in a straight line at constant velocity unless acted upon by an external force. This property is called inertia. For example, a book on a table stays at rest until someone pushes it, and a moving car continues moving until friction or braking stops it.",
|
| 150 |
+
"answer_guidelines": "Describe inertia and provide examples"
|
| 151 |
+
}}
|
| 152 |
+
|
| 153 |
+
----------------------------------------------------
|
| 154 |
+
|
| 155 |
+
CODE QUESTIONS
|
| 156 |
+
|
| 157 |
+
Rules:
|
| 158 |
+
|
| 159 |
+
starter_code must be either a string OR null.
|
| 160 |
+
Never output the string "None".
|
| 161 |
+
|
| 162 |
+
Example:
|
| 163 |
+
|
| 164 |
+
{{
|
| 165 |
+
"type": "code",
|
| 166 |
+
"question": "Write a Python function to compute factorial.",
|
| 167 |
+
"language": "c",
|
| 168 |
+
"starter_code": "def factorial(n):",
|
| 169 |
+
"solution": "def factorial(n): return 1 if n<=1 else n*factorial(n-1)",
|
| 170 |
+
"explanation": "Uses recursion"
|
| 171 |
+
}}
|
| 172 |
+
|
| 173 |
+
----------------------------------------------------
|
| 174 |
+
IMPORTANT RESTRICTIONS
|
| 175 |
+
----------------------------------------------------
|
| 176 |
+
|
| 177 |
+
Do NOT output:
|
| 178 |
+
|
| 179 |
+
LaTeX
|
| 180 |
+
math formulas
|
| 181 |
+
markdown
|
| 182 |
+
additional fields
|
| 183 |
+
|
| 184 |
+
Use plain text only.
|
| 185 |
+
|
| 186 |
+
----------------------------------------------------
|
| 187 |
+
FINAL JSON STRUCTURE
|
| 188 |
+
----------------------------------------------------
|
| 189 |
+
|
| 190 |
+
{{
|
| 191 |
+
"exam_id": "{request.exam_id}",
|
| 192 |
+
"difficulty": "{request.difficulty.value}",
|
| 193 |
+
"total_questions": {request.total_questions},
|
| 194 |
+
"expected_distribution": {json.dumps(distribution)},
|
| 195 |
+
"questions": []
|
| 196 |
+
}}
|
| 197 |
+
|
| 198 |
+
Fill the questions array with the generated questions.
|
| 199 |
+
|
| 200 |
+
----------------------------------------------------
|
| 201 |
+
|
| 202 |
+
Return ONLY the JSON object.
|
| 203 |
+
"""
|
| 204 |
+
|
| 205 |
+
def build_exam_evaluation_prompt(self,request: ExamGenerationRequest,exam: ExamResponse) -> str:
|
| 206 |
+
|
| 207 |
+
exam_json = exam.model_dump_json()
|
| 208 |
+
|
| 209 |
+
return f"""
|
| 210 |
+
You are an exam quality evaluator.
|
| 211 |
+
|
| 212 |
+
--------------------------------
|
| 213 |
+
OUTPUT RULES
|
| 214 |
+
--------------------------------
|
| 215 |
+
1. Output MUST be valid JSON.
|
| 216 |
+
2. Do NOT include markdown.
|
| 217 |
+
3. Do NOT include reasoning outside JSON.
|
| 218 |
+
4. Output ONLY the JSON object.
|
| 219 |
+
5. JSON must start with {{ and end with }}.
|
| 220 |
+
|
| 221 |
+
--------------------------------
|
| 222 |
+
SCORING RANGE
|
| 223 |
+
--------------------------------
|
| 224 |
+
0 to {self.MAX_SCORE}
|
| 225 |
+
|
| 226 |
+
--------------------------------
|
| 227 |
+
EVALUATION CRITERIA
|
| 228 |
+
--------------------------------
|
| 229 |
+
1. Relevance of questions to the topics
|
| 230 |
+
2. Correct distribution of question types
|
| 231 |
+
3. Clarity and wording of questions
|
| 232 |
+
4. Difficulty consistency
|
| 233 |
+
5. Correctness of answers
|
| 234 |
+
|
| 235 |
+
--------------------------------
|
| 236 |
+
EXAM TO EVALUATE
|
| 237 |
+
--------------------------------
|
| 238 |
+
{exam_json}
|
| 239 |
+
|
| 240 |
+
--------------------------------
|
| 241 |
+
OUTPUT FORMAT
|
| 242 |
+
--------------------------------
|
| 243 |
+
|
| 244 |
+
{{
|
| 245 |
+
"overall_score": integer between 0 and {self.MAX_SCORE},
|
| 246 |
+
"feedback": "short explanation of issues if any"
|
| 247 |
+
}}
|
| 248 |
+
|
| 249 |
+
Return ONLY JSON.
|
| 250 |
+
"""
|
indexing/indexingController.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from stores.llm.LLMProviderFactory import LLMProviderFactory
|
| 2 |
+
from stores.vector_store.Qdrant import QdrantStore
|
| 3 |
+
|
| 4 |
+
from ingestion.loaders.File_loader import load_file
|
| 5 |
+
from ingestion.chunkers.recursive_chunker import recursive_chunk
|
| 6 |
+
from ingestion.pdf_outline import extract_pdf_outline, build_page_bookmark_map , recursive_chunk_with_pages
|
| 7 |
+
from ingestion.loaders.pdf_loader import load_pdf_with_pages
|
| 8 |
+
|
| 9 |
+
from config import get_settings
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
from qdrant_client import QdrantClient , models
|
| 13 |
+
|
| 14 |
+
class IndexingController:
|
| 15 |
+
def __init__(self):
|
| 16 |
+
config = get_settings()
|
| 17 |
+
self.factory = LLMProviderFactory(config)
|
| 18 |
+
self.embedder = self.factory.create(config.EMBEDDING_BACKEND)
|
| 19 |
+
self.embedder.set_embedding_model(config.EMBEDDING_MODEL_ID, config.EMBEDDING_MODEL_SIZE)
|
| 20 |
+
if config.QDRANT_TYPE == "cloud":
|
| 21 |
+
self.vector_store_client = QdrantClient(url=config.QDRANT_DOCKER_URL,api_key=config.QDRANT_API_KEY,timeout=120)
|
| 22 |
+
elif config.QDRANT_TYPE == "docker":
|
| 23 |
+
self.vector_store_client = QdrantClient(url=config.QDRANT_DOCKER_URL,timeout=120)
|
| 24 |
+
elif config.QDRANT_TYPE == "local":
|
| 25 |
+
self.vector_store_client = QdrantClient(path="data/qdrant",prefer_grpc=False,timeout=120)
|
| 26 |
+
|
| 27 |
+
string_fields = ["metadata.username", "metadata.source", "metadata.course","metadata.bookmark_path"]
|
| 28 |
+
|
| 29 |
+
if not self.vector_store_client.collection_exists(collection_name=get_settings().QDRANT_COLLECTION):
|
| 30 |
+
# 2. Create the collection if it doesn't
|
| 31 |
+
self.vector_store_client.create_collection(
|
| 32 |
+
collection_name=get_settings().QDRANT_COLLECTION,
|
| 33 |
+
vectors_config=models.VectorParams(
|
| 34 |
+
size=get_settings().EMBEDDING_MODEL_SIZE,
|
| 35 |
+
distance=models.Distance.COSINE
|
| 36 |
+
),
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
for field in string_fields:
|
| 40 |
+
self.vector_store_client.create_payload_index(
|
| 41 |
+
collection_name=get_settings().QDRANT_COLLECTION,
|
| 42 |
+
field_name=field,
|
| 43 |
+
field_schema=models.KeywordIndexParams(
|
| 44 |
+
type=models.KeywordIndexType.KEYWORD
|
| 45 |
+
)
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
self.vector_store= QdrantStore(self.vector_store_client,config.QDRANT_COLLECTION, config.EMBEDDING_MODEL_SIZE)
|
| 49 |
+
|
| 50 |
+
def embed_chunks(self, chunks):
|
| 51 |
+
return self.embedder.embed_text_batch(chunks)
|
| 52 |
+
|
| 53 |
+
def process_file(self,file_path, original_filename, username=None, course=None):
|
| 54 |
+
file_name = os.path.basename(file_path)
|
| 55 |
+
ext = os.path.splitext(file_path)[1].lower()
|
| 56 |
+
|
| 57 |
+
bookmark_map = {}
|
| 58 |
+
|
| 59 |
+
if ext == ".pdf":
|
| 60 |
+
outline , total_pages= extract_pdf_outline(file_path)
|
| 61 |
+
bookmark_map = build_page_bookmark_map(outline , total_pages)
|
| 62 |
+
|
| 63 |
+
pages = load_pdf_with_pages(file_path)
|
| 64 |
+
chunks = recursive_chunk_with_pages(pages)
|
| 65 |
+
|
| 66 |
+
else:
|
| 67 |
+
text = load_file(file_path)
|
| 68 |
+
if isinstance(text, list):
|
| 69 |
+
text = " ".join([doc.page_content for doc in text])
|
| 70 |
+
chunks_text = recursive_chunk(text)
|
| 71 |
+
chunks = [{"text": c, "page": None} for c in chunks_text]
|
| 72 |
+
|
| 73 |
+
embeddings = self.embed_chunks([c["text"] for c in chunks])
|
| 74 |
+
|
| 75 |
+
valid_embs = []
|
| 76 |
+
valid_payloads = []
|
| 77 |
+
|
| 78 |
+
for idx, (chunk_obj, emb) in enumerate(zip(chunks, embeddings)):
|
| 79 |
+
if emb is not None:
|
| 80 |
+
page = chunk_obj["page"]
|
| 81 |
+
bookmark_path = bookmark_map.get(page, [])
|
| 82 |
+
|
| 83 |
+
valid_embs.append(emb)
|
| 84 |
+
valid_payloads.append({
|
| 85 |
+
"content": chunk_obj["text"],
|
| 86 |
+
"metadata": {
|
| 87 |
+
"source": original_filename,
|
| 88 |
+
"chunk_index": idx,
|
| 89 |
+
"total_chunks": len(chunks),
|
| 90 |
+
"username": username,
|
| 91 |
+
"course": course,
|
| 92 |
+
"page": page,
|
| 93 |
+
"bookmark_path": bookmark_path,
|
| 94 |
+
}
|
| 95 |
+
}
|
| 96 |
+
)
|
| 97 |
+
print(f"[DEBUG] Prepared payload for chunk {idx}: page={page}, bookmark_path={bookmark_path}")
|
| 98 |
+
|
| 99 |
+
self.vector_store.upsert_embeddings(
|
| 100 |
+
self.vector_store_client,
|
| 101 |
+
get_settings().QDRANT_COLLECTION,
|
| 102 |
+
valid_embs,
|
| 103 |
+
valid_payloads
|
| 104 |
+
)
|
| 105 |
+
print(f"[INFO] Stored {len(valid_embs)} embeddings for file '{file_name}'.")
|
| 106 |
+
|
| 107 |
+
return {
|
| 108 |
+
"num_chunks": len(chunks),
|
| 109 |
+
"chunks": chunks,
|
| 110 |
+
"embeddings": embeddings
|
| 111 |
+
}
|
ingestion/chunkers/__init__.py
ADDED
|
File without changes
|
ingestion/chunkers/fixed_chunker.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_text_splitters import CharacterTextSplitter
|
| 2 |
+
from config import get_settings
|
| 3 |
+
|
| 4 |
+
def fixed_chunk(text):
|
| 5 |
+
splitter = CharacterTextSplitter(
|
| 6 |
+
chunk_size=get_settings().CHUNK_SIZE,
|
| 7 |
+
chunk_overlap=get_settings().CHUNK_OVERLAP
|
| 8 |
+
)
|
| 9 |
+
chunks = splitter.split_text(text)
|
| 10 |
+
return chunks
|
ingestion/chunkers/recursive_chunker.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 2 |
+
from config import get_settings
|
| 3 |
+
|
| 4 |
+
def recursive_chunk(text):
|
| 5 |
+
splitter = RecursiveCharacterTextSplitter(
|
| 6 |
+
chunk_size=get_settings().CHUNK_SIZE,
|
| 7 |
+
chunk_overlap=get_settings().CHUNK_OVERLAP,
|
| 8 |
+
)
|
| 9 |
+
return splitter.split_text(text)
|
ingestion/loaders/File_loader.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from config import get_settings
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
def get_file_extension(file_id: str):
|
| 5 |
+
return os.path.splitext(file_id)[-1]
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def load_file(file_path: str):
|
| 9 |
+
if get_settings().CustomLoaders==True:
|
| 10 |
+
from ingestion.loaders.pdf_loader import load_pdf
|
| 11 |
+
from ingestion.loaders.txt_loader import load_txt
|
| 12 |
+
from ingestion.loaders.md_loader import load_md
|
| 13 |
+
from ingestion.loaders.docx_loader import load_docx
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
#Dispatcher
|
| 17 |
+
|
| 18 |
+
ext = os.path.splitext(file_path)[1].lower()
|
| 19 |
+
|
| 20 |
+
if ext == ".pdf":
|
| 21 |
+
docs = load_pdf(file_path)
|
| 22 |
+
elif ext == ".docx":
|
| 23 |
+
docs = load_docx(file_path)
|
| 24 |
+
elif ext == ".md":
|
| 25 |
+
docs = load_md(file_path)
|
| 26 |
+
elif ext == ".txt":
|
| 27 |
+
docs = load_txt(file_path)
|
| 28 |
+
else:
|
| 29 |
+
print(f"Unsupported file type: {ext}")
|
| 30 |
+
return []
|
| 31 |
+
|
| 32 |
+
# Return list of Document objects as-is
|
| 33 |
+
return docs
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
elif get_settings().CustomLoaders==False:
|
| 37 |
+
|
| 38 |
+
from langchain_community.document_loaders import (
|
| 39 |
+
TextLoader,
|
| 40 |
+
Docx2txtLoader,
|
| 41 |
+
UnstructuredMarkdownLoader,
|
| 42 |
+
PyMuPDFLoader,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
extension = get_file_extension(file_path)
|
| 47 |
+
|
| 48 |
+
if extension == ".txt":
|
| 49 |
+
return TextLoader(file_path, encoding="utf8").load()
|
| 50 |
+
elif extension == ".docx":
|
| 51 |
+
return Docx2txtLoader(file_path).load()
|
| 52 |
+
elif extension == ".md":
|
| 53 |
+
return UnstructuredMarkdownLoader(file_path).load()
|
| 54 |
+
elif extension in [".pdf"]:
|
| 55 |
+
return PyMuPDFLoader(file_path).load()
|
| 56 |
+
else:
|
| 57 |
+
raise ValueError(f"Unsupported file extension: {extension}")
|
ingestion/loaders/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
ingestion/loaders/docx_loader.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List
|
| 3 |
+
from langchain_core.documents import Document
|
| 4 |
+
from docx import Document as DocxDocument
|
| 5 |
+
from docx.oxml.table import CT_Tbl
|
| 6 |
+
from docx.oxml.text.paragraph import CT_P
|
| 7 |
+
from ingestion.loaders.normalization import normalize_text
|
| 8 |
+
|
| 9 |
+
def table_to_text(table) -> str:
|
| 10 |
+
"""Convert DOCX table to plain, readable text without numeric headers."""
|
| 11 |
+
data = []
|
| 12 |
+
try:
|
| 13 |
+
for row in table.rows:
|
| 14 |
+
row_data = [normalize_text(cell.text) for cell in row.cells]
|
| 15 |
+
if any(row_data): # skip empty rows
|
| 16 |
+
data.append(row_data)
|
| 17 |
+
|
| 18 |
+
if not data:
|
| 19 |
+
return ""
|
| 20 |
+
|
| 21 |
+
# Format as a readable markdown-like table instead of CSV with numbers
|
| 22 |
+
return "\n".join([" | ".join(row) for row in data])
|
| 23 |
+
|
| 24 |
+
except Exception as e:
|
| 25 |
+
print(f"Error converting table to text: {e}")
|
| 26 |
+
return ""
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def load_docx(file_path: str) -> List[Document]:
|
| 32 |
+
"""Load DOCX file safely, preserving tables and skipping corrupted sections."""
|
| 33 |
+
docs = []
|
| 34 |
+
|
| 35 |
+
if not os.path.exists(file_path):
|
| 36 |
+
print(f"File not found: {file_path}")
|
| 37 |
+
return []
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
doc = DocxDocument(file_path)
|
| 41 |
+
except Exception as e:
|
| 42 |
+
print(f"Failed to open DOCX ({file_path}): {e}")
|
| 43 |
+
return []
|
| 44 |
+
|
| 45 |
+
try:
|
| 46 |
+
body_elements = list(doc.element.body)
|
| 47 |
+
paragraph_iter = iter(doc.paragraphs)
|
| 48 |
+
table_iter = iter(doc.tables)
|
| 49 |
+
|
| 50 |
+
for element in body_elements:
|
| 51 |
+
if isinstance(element, CT_P):
|
| 52 |
+
try:
|
| 53 |
+
para = next(paragraph_iter)
|
| 54 |
+
cleaned = normalize_text(para.text)
|
| 55 |
+
if cleaned:
|
| 56 |
+
docs.append(
|
| 57 |
+
Document(
|
| 58 |
+
page_content=cleaned,
|
| 59 |
+
metadata={"source": file_path, "type": "text"},
|
| 60 |
+
)
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
except StopIteration:
|
| 64 |
+
continue
|
| 65 |
+
except Exception as e:
|
| 66 |
+
print(f"Error reading paragraph: {e}")
|
| 67 |
+
continue
|
| 68 |
+
elif isinstance(element, CT_Tbl):
|
| 69 |
+
try:
|
| 70 |
+
table = next(table_iter)
|
| 71 |
+
table_text = table_to_text(table)
|
| 72 |
+
if table_text:
|
| 73 |
+
docs.append(
|
| 74 |
+
Document(
|
| 75 |
+
page_content=table_text,
|
| 76 |
+
metadata={"source": file_path, "type": "table"},
|
| 77 |
+
)
|
| 78 |
+
)
|
| 79 |
+
except StopIteration:
|
| 80 |
+
continue
|
| 81 |
+
except Exception as e:
|
| 82 |
+
print(f"Error reading table: {e}")
|
| 83 |
+
continue
|
| 84 |
+
|
| 85 |
+
except Exception as e:
|
| 86 |
+
print(f"[WARN] Error processing DOCX ({file_path}): {e}")
|
| 87 |
+
return []
|
| 88 |
+
|
| 89 |
+
return docs
|
ingestion/loaders/md_loader.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
from typing import List
|
| 4 |
+
from langchain_core.documents import Document
|
| 5 |
+
from ingestion.loaders.normalization import normalize_text
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def load_md(file_path: str) -> List[Document]:
|
| 9 |
+
"""Load Markdown safely, preserving inline tables and skipping unreadable sections."""
|
| 10 |
+
if not os.path.exists(file_path):
|
| 11 |
+
print(f"File not found: {file_path}")
|
| 12 |
+
return []
|
| 13 |
+
|
| 14 |
+
text = ""
|
| 15 |
+
try:
|
| 16 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
| 17 |
+
text = f.read()
|
| 18 |
+
except UnicodeDecodeError:
|
| 19 |
+
try:
|
| 20 |
+
with open(file_path, "r", encoding="latin-1") as f:
|
| 21 |
+
text = f.read()
|
| 22 |
+
except Exception as e:
|
| 23 |
+
print(f"Failed to read Markdown file ({file_path}): {e}")
|
| 24 |
+
return []
|
| 25 |
+
except Exception as e:
|
| 26 |
+
print(f"Could not open Markdown file ({file_path}): {e}")
|
| 27 |
+
return []
|
| 28 |
+
|
| 29 |
+
docs = []
|
| 30 |
+
try:
|
| 31 |
+
# Split into segments alternating between text and tables
|
| 32 |
+
parts = re.split(r"((?:\|.*\|\n)+)", text)
|
| 33 |
+
for part in parts:
|
| 34 |
+
if not part.strip():
|
| 35 |
+
continue
|
| 36 |
+
|
| 37 |
+
# Detect if segment is a table
|
| 38 |
+
content_type = "table" if re.match(r"(?:\|.*\|\n)+", part) else "text"
|
| 39 |
+
|
| 40 |
+
# Clean markdown formatting but keep structure
|
| 41 |
+
cleaned = normalize_text(re.sub(r'(```.*?```|`.*?`|\*\*|__|#)', '', part, flags=re.DOTALL))
|
| 42 |
+
if cleaned:
|
| 43 |
+
docs.append(Document(page_content=cleaned, metadata={"source": file_path, "type": content_type}))
|
| 44 |
+
except Exception as e:
|
| 45 |
+
print(f"Error parsing Markdown file ({file_path}): {e}")
|
| 46 |
+
return []
|
| 47 |
+
|
| 48 |
+
return docs
|
ingestion/loaders/normalization.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def normalize_text(text: str) -> str:
|
| 5 |
+
"""Clean and normalize extracted text from any format (PDF/DOCX/MD/TXT)."""
|
| 6 |
+
if not text:
|
| 7 |
+
return ""
|
| 8 |
+
|
| 9 |
+
# Replace common PDF CID artifacts like (cid:1234)
|
| 10 |
+
text = re.sub(r'\(cid:\d+\)', '', text)
|
| 11 |
+
|
| 12 |
+
# Replace newlines/tabs with spaces
|
| 13 |
+
text = text.replace('\n', ' ').replace('\t', ' ')
|
| 14 |
+
|
| 15 |
+
# Remove emojis and pictographs
|
| 16 |
+
emoji_pattern = re.compile(
|
| 17 |
+
"["
|
| 18 |
+
"\U0001F600-\U0001F64F" # emoticons
|
| 19 |
+
"\U0001F300-\U0001F5FF" # symbols & pictographs
|
| 20 |
+
"\U0001F680-\U0001F6FF" # transport & map
|
| 21 |
+
"\U0001F1E0-\U0001F1FF" # flags
|
| 22 |
+
"\U00002500-\U00002BEF"
|
| 23 |
+
"\U00002700-\U000027BF"
|
| 24 |
+
"\U0001F900-\U0001F9FF"
|
| 25 |
+
"\U0001FA70-\U0001FAFF"
|
| 26 |
+
"\U00002600-\U000026FF"
|
| 27 |
+
"\U00002B00-\U00002BFF"
|
| 28 |
+
"]+", flags=re.UNICODE
|
| 29 |
+
)
|
| 30 |
+
text = emoji_pattern.sub("", text)
|
| 31 |
+
|
| 32 |
+
# Collapse multiple spaces
|
| 33 |
+
text = re.sub(r'\s+', ' ', text)
|
| 34 |
+
|
| 35 |
+
return text.strip()
|
ingestion/loaders/pdf_loader.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from langchain_core.documents import Document
|
| 3 |
+
import pdfplumber
|
| 4 |
+
from ingestion.loaders.normalization import normalize_text
|
| 5 |
+
|
| 6 |
+
def load_pdf(file_path: str):
|
| 7 |
+
documents = []
|
| 8 |
+
# Check if file exists
|
| 9 |
+
if not os.path.exists(file_path):
|
| 10 |
+
raise FileNotFoundError(f"File not found: {file_path}")
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
with pdfplumber.open(file_path) as pdf:
|
| 14 |
+
for page_num, page in enumerate(pdf.pages, start=1):
|
| 15 |
+
try:
|
| 16 |
+
text = page.extract_text() or ""
|
| 17 |
+
text = normalize_text(text)
|
| 18 |
+
tables = page.extract_tables() or []
|
| 19 |
+
|
| 20 |
+
# Reconstruct page text with tables preserved in order
|
| 21 |
+
page_content = text.strip()
|
| 22 |
+
for t_idx, table in enumerate(tables, start=1):
|
| 23 |
+
table_text = "\n".join(
|
| 24 |
+
["\t".join(cell if cell else "" for cell in row) for row in table]
|
| 25 |
+
)
|
| 26 |
+
table_text = normalize_text(table_text)
|
| 27 |
+
page_content += f"\n\n=== Table {t_idx} (Page {page_num}) ===\n{table_text}"
|
| 28 |
+
|
| 29 |
+
# Append as LangChain Document
|
| 30 |
+
documents.append(
|
| 31 |
+
Document(
|
| 32 |
+
page_content=page_content,
|
| 33 |
+
metadata={
|
| 34 |
+
"source": os.path.basename(file_path),
|
| 35 |
+
"page_number": page_num,
|
| 36 |
+
},
|
| 37 |
+
)
|
| 38 |
+
)
|
| 39 |
+
except Exception as e:
|
| 40 |
+
print(f"Error extracting page {page_num}: {e}")
|
| 41 |
+
continue # Skip corrupted pages, process others
|
| 42 |
+
|
| 43 |
+
except Exception as e:
|
| 44 |
+
print(f"Failed to open or read PDF file: {file_path}")
|
| 45 |
+
print(f"Error: {e}")
|
| 46 |
+
return [] # Return empty list instead of crashing
|
| 47 |
+
|
| 48 |
+
return documents
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def load_pdf_with_pages(file_path: str):
|
| 55 |
+
import fitz
|
| 56 |
+
doc = fitz.open(file_path)
|
| 57 |
+
pages = []
|
| 58 |
+
|
| 59 |
+
for i, page in enumerate(doc):
|
| 60 |
+
pages.append({
|
| 61 |
+
"page": i + 1,
|
| 62 |
+
"text": page.get_text()
|
| 63 |
+
})
|
| 64 |
+
|
| 65 |
+
return pages
|
| 66 |
+
|
ingestion/loaders/txt_loader.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List
|
| 3 |
+
from langchain_core.documents import Document
|
| 4 |
+
from ingestion.loaders.normalization import normalize_text
|
| 5 |
+
|
| 6 |
+
def load_txt(file_path: str) -> List[Document]:
|
| 7 |
+
"""Load plain text file safely, handling encoding issues."""
|
| 8 |
+
docs = []
|
| 9 |
+
|
| 10 |
+
if not os.path.exists(file_path):
|
| 11 |
+
print(f"File not found: {file_path}")
|
| 12 |
+
return docs
|
| 13 |
+
|
| 14 |
+
text = ""
|
| 15 |
+
try:
|
| 16 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
| 17 |
+
text = f.read()
|
| 18 |
+
except UnicodeDecodeError:
|
| 19 |
+
try:
|
| 20 |
+
with open(file_path, "r", encoding="latin-1") as f:
|
| 21 |
+
text = f.read()
|
| 22 |
+
except Exception as e:
|
| 23 |
+
print(f"Failed to read text file ({file_path}): {e}")
|
| 24 |
+
return docs
|
| 25 |
+
except Exception as e:
|
| 26 |
+
print(f"Could not open file ({file_path}): {e}")
|
| 27 |
+
return docs
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
cleaned = normalize_text(text)
|
| 31 |
+
if cleaned:
|
| 32 |
+
docs.append(
|
| 33 |
+
Document(page_content=cleaned, metadata={"source": file_path, "type": "text"})
|
| 34 |
+
)
|
| 35 |
+
except Exception as e:
|
| 36 |
+
print(f"Error processing text file ({file_path}): {e}")
|
| 37 |
+
|
| 38 |
+
return docs
|
ingestion/pdf_outline.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import fitz
|
| 2 |
+
from config import get_settings
|
| 3 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 4 |
+
|
| 5 |
+
def extract_pdf_outline(pdf_path: str):
|
| 6 |
+
doc = fitz.open(pdf_path)
|
| 7 |
+
toc = doc.get_toc(simple=False)
|
| 8 |
+
total_pages = doc.page_count
|
| 9 |
+
|
| 10 |
+
outline = []
|
| 11 |
+
stack = []
|
| 12 |
+
for level, title, page, *_ in toc:
|
| 13 |
+
while stack and stack[-1]["level"] >= level:
|
| 14 |
+
stack.pop()
|
| 15 |
+
node = {"level": level, "title": title, "page": page, "children": []}
|
| 16 |
+
if stack:
|
| 17 |
+
stack[-1]["children"].append(node)
|
| 18 |
+
else:
|
| 19 |
+
outline.append(node)
|
| 20 |
+
stack.append(node)
|
| 21 |
+
|
| 22 |
+
doc.close()
|
| 23 |
+
return outline , total_pages
|
| 24 |
+
|
| 25 |
+
def build_page_bookmark_map(outline_tree, total_pages: int):
|
| 26 |
+
explicit_map = {}
|
| 27 |
+
|
| 28 |
+
def walk(node, path):
|
| 29 |
+
current_path = path + [node["title"]]
|
| 30 |
+
explicit_map[node["page"]] = current_path
|
| 31 |
+
for child in node["children"]:
|
| 32 |
+
walk(child, current_path)
|
| 33 |
+
|
| 34 |
+
for root in outline_tree:
|
| 35 |
+
walk(root, [])
|
| 36 |
+
|
| 37 |
+
page_map = {}
|
| 38 |
+
last_known_path = []
|
| 39 |
+
|
| 40 |
+
for page_num in range(1, total_pages + 1):
|
| 41 |
+
if page_num in explicit_map:
|
| 42 |
+
last_known_path = explicit_map[page_num]
|
| 43 |
+
page_map[page_num] = last_known_path # carries forward last bookmark
|
| 44 |
+
|
| 45 |
+
return page_map
|
| 46 |
+
|
| 47 |
+
def recursive_chunk_with_pages(pages):
|
| 48 |
+
splitter = RecursiveCharacterTextSplitter(
|
| 49 |
+
chunk_size=get_settings().CHUNK_SIZE,
|
| 50 |
+
chunk_overlap=get_settings().CHUNK_OVERLAP,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
chunks = []
|
| 54 |
+
for p in pages:
|
| 55 |
+
page_chunks = splitter.split_text(p["text"])
|
| 56 |
+
for c in page_chunks:
|
| 57 |
+
chunks.append({
|
| 58 |
+
"text": c,
|
| 59 |
+
"page": p["page"]
|
| 60 |
+
})
|
| 61 |
+
|
| 62 |
+
return chunks
|
main.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI
|
| 2 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
+
from routes.base import base_router
|
| 4 |
+
from routes.assisstant_rag import assisstant_router
|
| 5 |
+
from routes.exam_router import exam_router
|
| 6 |
+
from routes.exam_grading_router import grading_router
|
| 7 |
+
|
| 8 |
+
app = FastAPI()
|
| 9 |
+
app.add_middleware(
|
| 10 |
+
CORSMiddleware,
|
| 11 |
+
allow_origins=["*"],
|
| 12 |
+
allow_credentials=True,
|
| 13 |
+
allow_methods=["*"],
|
| 14 |
+
allow_headers=["*"],
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
app.include_router(base_router)
|
| 18 |
+
app.include_router(assisstant_router)
|
| 19 |
+
app.include_router(exam_router)
|
| 20 |
+
app.include_router(grading_router)
|
requirements.txt
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.120.0
|
| 2 |
+
uvicorn==0.38.0
|
| 3 |
+
python-dotenv==1.2.1
|
| 4 |
+
pdfplumber==0.11.7
|
| 5 |
+
python-docx==1.2.0
|
| 6 |
+
pandas==2.3.3
|
| 7 |
+
langchain==1.0.2
|
| 8 |
+
unstructured==0.18.15
|
| 9 |
+
PyMuPDF==1.26.5
|
| 10 |
+
docx2txt==0.9
|
| 11 |
+
Markdown==3.9
|
| 12 |
+
python-multipart==0.0.20
|
| 13 |
+
cohere==5.5.8
|
| 14 |
+
openai==1.35.13
|
| 15 |
+
qdrant-client== 1.16.1
|
| 16 |
+
httpx==0.28.1
|
| 17 |
+
redis==7.2.0
|
| 18 |
+
celery==5.6.2
|
| 19 |
+
json_repair==0.58.5
|
routes/__init__.py
ADDED
|
File without changes
|
routes/assisstant_rag.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter , UploadFile, File
|
| 2 |
+
from routes.schemas.Requests_Models import ChatRequest
|
| 3 |
+
from generation.AssistantRagGenerator import AssistantRagGen
|
| 4 |
+
from indexing.indexingController import IndexingController
|
| 5 |
+
from uuid import uuid4
|
| 6 |
+
from worker.tasks import process_file_task
|
| 7 |
+
from celery.result import AsyncResult
|
| 8 |
+
from celery_app import celery_app
|
| 9 |
+
|
| 10 |
+
assisstant_router = APIRouter(tags=["assistant_rag"])
|
| 11 |
+
|
| 12 |
+
@assisstant_router.get("/jobs/{job_id}")
|
| 13 |
+
def get_job_status(job_id: str):
|
| 14 |
+
result = AsyncResult(job_id, app=celery_app)
|
| 15 |
+
if result.state == "PENDING":
|
| 16 |
+
return {"job_id": job_id,"state": result.state,"message": "Job is waiting in queue",}
|
| 17 |
+
|
| 18 |
+
if result.state == "STARTED":
|
| 19 |
+
return {
|
| 20 |
+
"job_id": job_id,
|
| 21 |
+
"state": result.state,
|
| 22 |
+
"message": "Job is currently processing",
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
if result.state == "SUCCESS":
|
| 26 |
+
return {
|
| 27 |
+
"job_id": job_id,
|
| 28 |
+
"state": result.state,
|
| 29 |
+
"result": result.result,
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
if result.state == "FAILURE":
|
| 33 |
+
return {
|
| 34 |
+
"job_id": job_id,
|
| 35 |
+
"state": result.state,
|
| 36 |
+
"error": str(result.result),
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
return {
|
| 40 |
+
"job_id": job_id,
|
| 41 |
+
"state": result.state,
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
@assisstant_router.post("/process-file")
|
| 45 |
+
async def process_file_endpoint(course: str , username: str , file: UploadFile = File(...)):
|
| 46 |
+
job_id = uuid4().hex
|
| 47 |
+
temp_path = f"./temp_{job_id}_{file.filename}"
|
| 48 |
+
with open(temp_path, "wb") as f:
|
| 49 |
+
f.write(await file.read())
|
| 50 |
+
task = process_file_task.delay(temp_path, file.filename, username, course)
|
| 51 |
+
return {
|
| 52 |
+
"job_id": task.id,
|
| 53 |
+
"filename": file.filename,
|
| 54 |
+
"status": "queued",
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
@assisstant_router.post("/chat/complete")
|
| 58 |
+
async def chat_complete_endpoint(request: ChatRequest):
|
| 59 |
+
indexing_controller = IndexingController()
|
| 60 |
+
rag_gen = AssistantRagGen()
|
| 61 |
+
user_query = request.prompt if request.prompt else "no question provided"
|
| 62 |
+
route = rag_gen.robust_router({"question": user_query})
|
| 63 |
+
|
| 64 |
+
results = []
|
| 65 |
+
context_text = ""
|
| 66 |
+
filters = []
|
| 67 |
+
|
| 68 |
+
# Kda Kda pdf :)
|
| 69 |
+
if request.source_file or request.bookmark:
|
| 70 |
+
if request.bookmark and not request.source_file:
|
| 71 |
+
request.bookmark=None
|
| 72 |
+
route = "pdf_query"
|
| 73 |
+
|
| 74 |
+
if route == "user_info":
|
| 75 |
+
if request.role == "instructor" or request.role == "admin":
|
| 76 |
+
context_text = (
|
| 77 |
+
f"User Profile Info: {request.user_info.model_dump()}\n"
|
| 78 |
+
f"Role: {request.role}\n"
|
| 79 |
+
f"Username: {request.username}"
|
| 80 |
+
)
|
| 81 |
+
elif request.role == "student":
|
| 82 |
+
request.user_info=request.user_info.copy(update={"instructor_owned_files": None})
|
| 83 |
+
context_text = (
|
| 84 |
+
f"User Profile Info: {request.user_info.model_dump()}\n"
|
| 85 |
+
f"Role: {request.role}\n"
|
| 86 |
+
f"Username: {request.username}"
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
elif route == "site_query":
|
| 90 |
+
filters = [
|
| 91 |
+
{"field": "course", "op": "eq", "value": "Instructions", "clause": "must"},
|
| 92 |
+
{"field": "username", "op": "eq", "value": "ADMIN", "clause": "must"}
|
| 93 |
+
]
|
| 94 |
+
embedding = indexing_controller.embedder.embed_text(user_query)
|
| 95 |
+
results = indexing_controller.vector_store.query_qdrant(
|
| 96 |
+
filters=filters,
|
| 97 |
+
embedding=embedding,
|
| 98 |
+
top_k=request.top_k
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
elif route == "pdf_query":
|
| 102 |
+
if request.role == "student":
|
| 103 |
+
enrolled = request.user_info.courses or []
|
| 104 |
+
print(f"[DEBUG] Student {request.username} is enrolled in courses: {enrolled}")
|
| 105 |
+
filters.append({"field": "course", "op": "in", "value": enrolled, "clause": "must"})
|
| 106 |
+
|
| 107 |
+
elif request.role == "instructor":
|
| 108 |
+
owned = request.user_info.courses
|
| 109 |
+
# if owned == []:
|
| 110 |
+
# owned = indexing_controller.vector_store.all_user_files_bookmarks(request.username)
|
| 111 |
+
# owned = owned.keys()
|
| 112 |
+
print(f"[DEBUG] Instructor {request.username} owns courses/files: {owned}")
|
| 113 |
+
filters.append({"field": "course", "op": "in", "value": owned, "clause": "must"})
|
| 114 |
+
|
| 115 |
+
if request.source_file:
|
| 116 |
+
filters.append({"field": "source", "op": "eq", "value": request.source_file, "clause": "must"})
|
| 117 |
+
|
| 118 |
+
if request.bookmark:
|
| 119 |
+
filters.append({"field": "bookmark_path", "op": "text", "value": request.bookmark, "clause": "must"})
|
| 120 |
+
|
| 121 |
+
embedding = indexing_controller.embedder.embed_text(user_query)
|
| 122 |
+
results = indexing_controller.vector_store.query_qdrant(
|
| 123 |
+
filters=filters,
|
| 124 |
+
embedding=embedding,
|
| 125 |
+
top_k=request.top_k
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
if not context_text and results:
|
| 129 |
+
context_text = "\n\n".join([r["content"] for r in results if r.get("content")])
|
| 130 |
+
|
| 131 |
+
history_str = "\n".join(
|
| 132 |
+
f"Human: {turn.Human_msg}\nAssistant: {turn.LLM_response}"
|
| 133 |
+
for turn in request.history
|
| 134 |
+
) if request.history else "None"
|
| 135 |
+
|
| 136 |
+
if route == "user_info":
|
| 137 |
+
final_prompt = rag_gen.build_user_info_prompt(
|
| 138 |
+
question=user_query,
|
| 139 |
+
conversation_history=history_str,
|
| 140 |
+
User_Info=str(request.user_info.model_dump()),
|
| 141 |
+
)
|
| 142 |
+
elif route == "site_query":
|
| 143 |
+
final_prompt = rag_gen.build_site_query_prompt(
|
| 144 |
+
question=user_query,
|
| 145 |
+
context=context_text,
|
| 146 |
+
conversation_history=history_str
|
| 147 |
+
)
|
| 148 |
+
else:
|
| 149 |
+
final_prompt = rag_gen.build_unified_prompt(
|
| 150 |
+
context=context_text,
|
| 151 |
+
question=user_query,
|
| 152 |
+
conversation_history=history_str,
|
| 153 |
+
User_Info=str(request.user_info.model_dump()),
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
llm_response = rag_gen.generator.generate_text(prompt=final_prompt)
|
| 157 |
+
|
| 158 |
+
return {
|
| 159 |
+
"session_id": request.session_id, # Return as is
|
| 160 |
+
"route": route,
|
| 161 |
+
"query": user_query,
|
| 162 |
+
"history": request.history, # Return as is
|
| 163 |
+
"results": results,
|
| 164 |
+
"LLM_answer": llm_response,
|
| 165 |
+
}
|
routes/base.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter , Depends
|
| 2 |
+
from config import get_settings
|
| 3 |
+
from indexing.indexingController import IndexingController
|
| 4 |
+
|
| 5 |
+
base_router = APIRouter(tags=["base"])
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@base_router.get("/health")
|
| 9 |
+
async def health_check(settings = Depends(get_settings)):
|
| 10 |
+
return {"status": "ok", "app_name": settings}
|
| 11 |
+
|
| 12 |
+
# @base_router.post("/all_docs")
|
| 13 |
+
# async def get_all_docs():
|
| 14 |
+
# indexing_controller = IndexingController()
|
| 15 |
+
# all_docs = indexing_controller.vector_store.get_all_documents()
|
| 16 |
+
# return {
|
| 17 |
+
# "total_docs": len(all_docs),
|
| 18 |
+
# "documents": all_docs
|
| 19 |
+
# }
|
| 20 |
+
|
| 21 |
+
@base_router.get("/all_files")
|
| 22 |
+
async def get__files():
|
| 23 |
+
indexing_controller = IndexingController()
|
| 24 |
+
all_files = indexing_controller.vector_store.get_all_files()
|
| 25 |
+
return {
|
| 26 |
+
"total_files": len(all_files),
|
| 27 |
+
"files": all_files,}
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@base_router.get("/remove_file")
|
| 31 |
+
async def remove_file(filename: str,username: str ,course: str):
|
| 32 |
+
indexing_controller = IndexingController()
|
| 33 |
+
result = indexing_controller.vector_store.remove_points_by_file(filename,username,course)
|
| 34 |
+
return {
|
| 35 |
+
"status": "success" if result else "failure",
|
| 36 |
+
"message": f"File '{filename}' removed." if result else f"File '{filename}' not found."
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
@base_router.get("/user/docs")
|
| 40 |
+
async def get_user_docs(username: str):
|
| 41 |
+
indexing_controller = IndexingController()
|
| 42 |
+
user_docs = indexing_controller.vector_store.all_user_files_bookmarks(username)
|
| 43 |
+
return {
|
| 44 |
+
"total_docs": len(user_docs),
|
| 45 |
+
"documents": user_docs}
|
routes/exam_grading_router.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, HTTPException
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
from generation.ExamAnswer import ExamGradingService, grade_exam_task
|
| 6 |
+
from generation.answer_models import ExamSubmission, ExamResult
|
| 7 |
+
from routes.schemas.Exam_Models import ExamResponse
|
| 8 |
+
|
| 9 |
+
grading_router = APIRouter(prefix="/exam/grading", tags=["exam_grading"])
|
| 10 |
+
|
| 11 |
+
class GradingResponse(BaseModel):
|
| 12 |
+
job_id: str
|
| 13 |
+
exam_id: str
|
| 14 |
+
student_id: str
|
| 15 |
+
status: str
|
| 16 |
+
|
| 17 |
+
class GradingRequest(BaseModel):
|
| 18 |
+
submission: ExamSubmission
|
| 19 |
+
exam: ExamResponse
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def normalize_text(text: str) -> str:
|
| 23 |
+
if not text:
|
| 24 |
+
return ""
|
| 25 |
+
text = re.sub(r'[^\w\s]', '', text)
|
| 26 |
+
text = re.sub(r'\s+', ' ', text)
|
| 27 |
+
return text.strip().lower()
|
| 28 |
+
|
| 29 |
+
@grading_router.post("/submit", response_model=GradingResponse)
|
| 30 |
+
async def submit_exam(request: GradingRequest):
|
| 31 |
+
submission_dict = request.submission.model_dump()
|
| 32 |
+
exam_questions_map = {}
|
| 33 |
+
|
| 34 |
+
for q in request.exam.questions:
|
| 35 |
+
normalized_q = normalize_text(q.question)
|
| 36 |
+
exam_questions_map[normalized_q] = q
|
| 37 |
+
|
| 38 |
+
for answer in submission_dict["answers"]:
|
| 39 |
+
question_text = answer["question_text"]
|
| 40 |
+
question_type = answer["question_type"]
|
| 41 |
+
normalized_answer_text = normalize_text(question_text)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
correct_answer = None
|
| 45 |
+
if normalized_answer_text in exam_questions_map:
|
| 46 |
+
q = exam_questions_map[normalized_answer_text]
|
| 47 |
+
|
| 48 |
+
if question_type == "multiple_choice" and hasattr(q, 'correct_answer'):
|
| 49 |
+
correct_answer = q.correct_answer
|
| 50 |
+
elif question_type == "true_false" and hasattr(q, 'correct_answer'):
|
| 51 |
+
correct_answer = q.correct_answer
|
| 52 |
+
elif question_type == "short_answer" and hasattr(q, 'answer'):
|
| 53 |
+
correct_answer = q.answer
|
| 54 |
+
elif question_type == "code" and hasattr(q, 'solution'):
|
| 55 |
+
correct_answer = q.solution
|
| 56 |
+
elif question_type == "essay":
|
| 57 |
+
if hasattr(q, 'answer_guidelines') and q.answer_guidelines:
|
| 58 |
+
correct_answer = q.answer_guidelines
|
| 59 |
+
elif hasattr(q, 'answer'):
|
| 60 |
+
correct_answer = q.answer
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
if "metadata" not in answer:
|
| 64 |
+
answer["metadata"] = {}
|
| 65 |
+
answer["metadata"]["correct_answer"] = correct_answer
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
task = grade_exam_task.delay(submission_dict)
|
| 69 |
+
|
| 70 |
+
return GradingResponse(
|
| 71 |
+
job_id=task.id,
|
| 72 |
+
exam_id=request.submission.exam_id,
|
| 73 |
+
student_id=request.submission.student_id,
|
| 74 |
+
status="queued"
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
@grading_router.post("/grade-sync", response_model=ExamResult)
|
| 78 |
+
async def grade_sync(request: GradingRequest):
|
| 79 |
+
try:
|
| 80 |
+
service = ExamGradingService(use_ai_for_essays=True)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
exam_questions_map = {}
|
| 84 |
+
for q in request.exam.questions:
|
| 85 |
+
normalized_q = normalize_text(q.question)
|
| 86 |
+
exam_questions_map[normalized_q] = q
|
| 87 |
+
|
| 88 |
+
for ans in request.submission.answers:
|
| 89 |
+
question_text = ans.question_text
|
| 90 |
+
question_type = ans.question_type
|
| 91 |
+
normalized_answer_text = normalize_text(question_text)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
correct_answer = None
|
| 95 |
+
if normalized_answer_text in exam_questions_map:
|
| 96 |
+
q = exam_questions_map[normalized_answer_text]
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
if question_type == "multiple_choice" and hasattr(q, 'correct_answer'):
|
| 100 |
+
correct_answer = q.correct_answer
|
| 101 |
+
elif question_type == "true_false" and hasattr(q, 'correct_answer'):
|
| 102 |
+
correct_answer = q.correct_answer
|
| 103 |
+
elif question_type == "short_answer" and hasattr(q, 'answer'):
|
| 104 |
+
correct_answer = q.answer
|
| 105 |
+
elif question_type == "code" and hasattr(q, 'solution'):
|
| 106 |
+
correct_answer = q.solution
|
| 107 |
+
elif question_type == "essay":
|
| 108 |
+
if hasattr(q, 'answer_guidelines') and q.answer_guidelines:
|
| 109 |
+
correct_answer = q.answer_guidelines
|
| 110 |
+
elif hasattr(q, 'answer'):
|
| 111 |
+
correct_answer = q.answer
|
| 112 |
+
|
| 113 |
+
if correct_answer is not None:
|
| 114 |
+
if not ans.metadata:
|
| 115 |
+
ans.metadata = {}
|
| 116 |
+
ans.metadata["correct_answer"] = correct_answer
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
result = service.grade_submission(request.submission)
|
| 120 |
+
return result
|
| 121 |
+
except Exception as e:
|
| 122 |
+
raise HTTPException(status_code=400, detail=str(e))
|
routes/exam_router.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter
|
| 2 |
+
from routes.schemas.Exam_Models import ExamGenerationRequest
|
| 3 |
+
from worker.tasks import generate_exam_task
|
| 4 |
+
|
| 5 |
+
exam_router = APIRouter(prefix="/exam", tags=["exam"])
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@exam_router.post("/create")
|
| 9 |
+
async def process_file_endpoint(request: ExamGenerationRequest):
|
| 10 |
+
task = generate_exam_task.delay(request.model_dump())
|
| 11 |
+
return {
|
| 12 |
+
"job_id": task.id,
|
| 13 |
+
"exam_id": request.exam_id,
|
| 14 |
+
"status": "queued",
|
| 15 |
+
}
|
routes/schemas/Exam_Models.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, field_validator, model_validator
|
| 2 |
+
from typing import List, Optional, Dict
|
| 3 |
+
from enum import Enum
|
| 4 |
+
from typing import Union
|
| 5 |
+
from typing import Literal
|
| 6 |
+
from pydantic import Field
|
| 7 |
+
from typing import Annotated
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class QuestionType(str, Enum):
|
| 11 |
+
MCQ = "mcq"
|
| 12 |
+
TRUE_FALSE = "true_false"
|
| 13 |
+
SHORT_ANSWER = "short_answer"
|
| 14 |
+
ESSAY = "essay"
|
| 15 |
+
CODE = "code"
|
| 16 |
+
|
| 17 |
+
class DifficultyLevel(str, Enum):
|
| 18 |
+
EASY = "easy"
|
| 19 |
+
MEDIUM = "medium"
|
| 20 |
+
HARD = "hard"
|
| 21 |
+
|
| 22 |
+
class Reference(BaseModel):
|
| 23 |
+
filename: str
|
| 24 |
+
bookmarks: Optional[List[str]] = None
|
| 25 |
+
|
| 26 |
+
class ExamGenerationRequest(BaseModel):
|
| 27 |
+
username: str
|
| 28 |
+
course: str
|
| 29 |
+
exam_id: str
|
| 30 |
+
total_questions: int
|
| 31 |
+
topics: List[str]
|
| 32 |
+
references: Optional[List[Reference]] = None
|
| 33 |
+
difficulty: Optional[DifficultyLevel] = DifficultyLevel.MEDIUM
|
| 34 |
+
include_answer_key: Optional[bool] = True
|
| 35 |
+
question_types_distribution: Dict[QuestionType, int]
|
| 36 |
+
model_config = {"extra": "ignore"}
|
| 37 |
+
|
| 38 |
+
@field_validator("topics")
|
| 39 |
+
@classmethod
|
| 40 |
+
def validate_topics(cls, v):
|
| 41 |
+
if not v:
|
| 42 |
+
raise ValueError("Topics cannot be empty")
|
| 43 |
+
return v
|
| 44 |
+
|
| 45 |
+
@field_validator("question_types_distribution")
|
| 46 |
+
@classmethod
|
| 47 |
+
def validate_positive(cls, v):
|
| 48 |
+
if any(count <= 0 for count in v.values()):
|
| 49 |
+
raise ValueError("All distribution counts must be > 0")
|
| 50 |
+
return v
|
| 51 |
+
|
| 52 |
+
@model_validator(mode="after")
|
| 53 |
+
def validate_sum(self):
|
| 54 |
+
if sum(self.question_types_distribution.values()) != self.total_questions:
|
| 55 |
+
raise ValueError("Distribution must equal total_questions")
|
| 56 |
+
return self
|
| 57 |
+
|
| 58 |
+
class QuestionBase(BaseModel):
|
| 59 |
+
type: QuestionType
|
| 60 |
+
question: str
|
| 61 |
+
model_config = {"extra": "ignore"}
|
| 62 |
+
|
| 63 |
+
class MCQQuestion(QuestionBase):
|
| 64 |
+
type: Literal[QuestionType.MCQ]
|
| 65 |
+
options: List[str]
|
| 66 |
+
correct_answer: str
|
| 67 |
+
explanation: str
|
| 68 |
+
|
| 69 |
+
@model_validator(mode="after")
|
| 70 |
+
def validate_mcq(self):
|
| 71 |
+
if len(self.options) < 2:
|
| 72 |
+
raise ValueError("MCQ must contain at least 2 options")
|
| 73 |
+
if self.correct_answer not in self.options:
|
| 74 |
+
raise ValueError("correct_answer must exist in options")
|
| 75 |
+
return self
|
| 76 |
+
|
| 77 |
+
class TrueFalseQuestion(QuestionBase):
|
| 78 |
+
type: Literal[QuestionType.TRUE_FALSE]
|
| 79 |
+
correct_answer: bool
|
| 80 |
+
explanation: str
|
| 81 |
+
|
| 82 |
+
class ShortAnswerQuestion(QuestionBase):
|
| 83 |
+
type: Literal[QuestionType.SHORT_ANSWER]
|
| 84 |
+
answer: str
|
| 85 |
+
explanation: str
|
| 86 |
+
|
| 87 |
+
class EssayQuestion(QuestionBase):
|
| 88 |
+
type: Literal[QuestionType.ESSAY]
|
| 89 |
+
answer: str
|
| 90 |
+
answer_guidelines: str
|
| 91 |
+
|
| 92 |
+
class CodeQuestion(QuestionBase):
|
| 93 |
+
type: Literal[QuestionType.CODE]
|
| 94 |
+
|
| 95 |
+
starter_code: Optional[str] = Field(
|
| 96 |
+
default=None,
|
| 97 |
+
description="Starter code shown to the student"
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
language: str= "c"
|
| 101 |
+
|
| 102 |
+
solution: str = Field(
|
| 103 |
+
description="Correct full solution code"
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
explanation: str = Field(
|
| 107 |
+
description="Explanation of how the solution works"
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
@field_validator("starter_code", "solution")
|
| 111 |
+
@classmethod
|
| 112 |
+
def normalize_code(cls, v):
|
| 113 |
+
"""Convert escaped newlines to real newlines if present."""
|
| 114 |
+
if v:
|
| 115 |
+
return v.replace("\\n", "\n")
|
| 116 |
+
return v
|
| 117 |
+
|
| 118 |
+
QuestionUnion = Annotated[
|
| 119 |
+
Union[
|
| 120 |
+
MCQQuestion,
|
| 121 |
+
TrueFalseQuestion,
|
| 122 |
+
ShortAnswerQuestion,
|
| 123 |
+
EssayQuestion,
|
| 124 |
+
CodeQuestion,
|
| 125 |
+
],
|
| 126 |
+
Field(discriminator="type"),
|
| 127 |
+
]
|
| 128 |
+
|
| 129 |
+
class ExamResponse(BaseModel):
|
| 130 |
+
exam_id: str
|
| 131 |
+
difficulty: DifficultyLevel
|
| 132 |
+
total_questions: int
|
| 133 |
+
questions: List[QuestionUnion]
|
| 134 |
+
expected_distribution: Dict[QuestionType, int]
|
| 135 |
+
model_config = {"extra": "ignore"}
|
| 136 |
+
|
| 137 |
+
@model_validator(mode="after")
|
| 138 |
+
def validate_question_count(self):
|
| 139 |
+
if len(self.questions) != self.total_questions:
|
| 140 |
+
raise ValueError(
|
| 141 |
+
f"Expected {self.total_questions} questions, "
|
| 142 |
+
f"but got {len(self.questions)}"
|
| 143 |
+
)
|
| 144 |
+
return self
|
| 145 |
+
@model_validator(mode="after")
|
| 146 |
+
def validate_distribution(self):
|
| 147 |
+
|
| 148 |
+
actual_counts: Dict[QuestionType, int] = {}
|
| 149 |
+
|
| 150 |
+
for q in self.questions:
|
| 151 |
+
actual_counts[q.type] = actual_counts.get(q.type, 0) + 1
|
| 152 |
+
|
| 153 |
+
if set(actual_counts.keys()) != set(self.expected_distribution.keys()):
|
| 154 |
+
raise ValueError("Unexpected question types in exam")
|
| 155 |
+
|
| 156 |
+
for q_type, expected_count in self.expected_distribution.items():
|
| 157 |
+
actual = actual_counts.get(q_type, 0)
|
| 158 |
+
|
| 159 |
+
if actual != expected_count:
|
| 160 |
+
raise ValueError(
|
| 161 |
+
f"Distribution mismatch for {q_type.value}: "
|
| 162 |
+
f"expected {expected_count}, got {actual}"
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
return self
|
| 166 |
+
|
| 167 |
+
class AnswerItem(BaseModel):
|
| 168 |
+
question_index: int
|
| 169 |
+
answer: str
|
| 170 |
+
|
| 171 |
+
class AnswerKey(BaseModel):
|
| 172 |
+
exam_id: str
|
| 173 |
+
answers: List[AnswerItem]
|
| 174 |
+
model_config = {"extra": "ignore"}
|
| 175 |
+
|
| 176 |
+
class EvaluationResult(BaseModel):
|
| 177 |
+
overall_score: int
|
| 178 |
+
feedback: str
|
| 179 |
+
model_config = {"extra": "ignore"}
|
| 180 |
+
|
routes/schemas/Requests_Models.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
from typing import Optional, List
|
| 3 |
+
|
| 4 |
+
class ConversationTurn(BaseModel):
|
| 5 |
+
Human_msg: str
|
| 6 |
+
LLM_response: str
|
| 7 |
+
|
| 8 |
+
class UserInfoRequest(BaseModel):
|
| 9 |
+
courses: Optional[List[str]] = None
|
| 10 |
+
deadlines: Optional[List[str]] = None
|
| 11 |
+
grades: Optional[List[str]] = None
|
| 12 |
+
instructor_owned_files: Optional[List[str]] = None
|
| 13 |
+
more_info: Optional[str] = None
|
| 14 |
+
|
| 15 |
+
class ChatRequest(BaseModel):
|
| 16 |
+
prompt: Optional[str] = None
|
| 17 |
+
username: str
|
| 18 |
+
session_id: str
|
| 19 |
+
role: str
|
| 20 |
+
top_k: int = 5
|
| 21 |
+
source_file: Optional[str] = None
|
| 22 |
+
bookmark: Optional[str] = None
|
| 23 |
+
history: Optional[List[ConversationTurn]] = None
|
| 24 |
+
user_info: Optional[UserInfoRequest]= None
|
routes/schemas/__init__.py
ADDED
|
File without changes
|
stores/llm/LLMEnums.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
class LLMEnums(Enum):
|
| 4 |
+
OPENAI = "OPENAI"
|
| 5 |
+
COHERE = "COHERE"
|
| 6 |
+
OLLAMA = "OLLAMA"
|
| 7 |
+
MISTRAL = "MISTRAL"
|
| 8 |
+
GROQ = "GROQ"
|
| 9 |
+
OPENROUTER = "OPENROUTER"
|
| 10 |
+
HUGGINGFACE = "HUGGINGFACE"
|
| 11 |
+
DEEPSEEK = "DEEPSEEK"
|
| 12 |
+
GEMINI = "GEMINI"
|
| 13 |
+
|
| 14 |
+
class OpenAIEnums(Enum):
|
| 15 |
+
SYSTEM = "system"
|
| 16 |
+
USER = "user"
|
| 17 |
+
ASSISTANT = "assistant"
|
| 18 |
+
|
| 19 |
+
class CoHereEnums(Enum):
|
| 20 |
+
SYSTEM = "SYSTEM"
|
| 21 |
+
USER = "USER"
|
| 22 |
+
ASSISTANT = "CHATBOT"
|
| 23 |
+
DOCUMENT = "search_document"
|
| 24 |
+
QUERY = "search_query"
|
| 25 |
+
|
| 26 |
+
class DocumentTypeEnum(Enum):
|
| 27 |
+
DOCUMENT = "document"
|
| 28 |
+
QUERY = "query"
|
stores/llm/LLMInterface.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
|
| 3 |
+
class LLMInterface(ABC):
|
| 4 |
+
|
| 5 |
+
@abstractmethod
|
| 6 |
+
def set_generation_model(self, model_id: str):
|
| 7 |
+
pass
|
| 8 |
+
|
| 9 |
+
@abstractmethod
|
| 10 |
+
def set_embedding_model(self, model_id: str, embedding_size: int):
|
| 11 |
+
pass
|
| 12 |
+
|
| 13 |
+
@abstractmethod
|
| 14 |
+
def generate_text(self, prompt: str, chat_history: list=[], max_output_tokens: int=None,
|
| 15 |
+
temperature: float = None):
|
| 16 |
+
pass
|
| 17 |
+
|
| 18 |
+
@abstractmethod
|
| 19 |
+
def embed_text_batch(self, texts: list[str], batch_size: int = 32):
|
| 20 |
+
pass
|
| 21 |
+
|
| 22 |
+
@abstractmethod
|
| 23 |
+
def construct_prompt(self, prompt: str, role: str):
|
| 24 |
+
pass
|
stores/llm/LLMProviderFactory.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .LLMEnums import LLMEnums
|
| 2 |
+
from stores.llm.providers.OpenAIProvider import OpenAIProvider
|
| 3 |
+
from stores.llm.providers.OllamaProvider import OllamaProvider
|
| 4 |
+
from stores.llm.providers.CohereProvider import CohereProvider
|
| 5 |
+
from stores.llm.providers.MistralProvider import MistralProvider
|
| 6 |
+
from stores.llm.providers.GroqProvider import GroqProvider
|
| 7 |
+
from stores.llm.providers.OpenRouterProvider import OpenRouterProvider
|
| 8 |
+
from stores.llm.providers.HuggingFaceProvider import HuggingFaceProvider
|
| 9 |
+
from stores.llm.providers.DeepSeekProvider import DeepSeekProvider
|
| 10 |
+
from stores.llm.providers.GeminiProvider import GeminiProvider
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class LLMProviderFactory:
|
| 14 |
+
def __init__(self, config: dict):
|
| 15 |
+
self.config = config
|
| 16 |
+
|
| 17 |
+
def create(self, provider: str):
|
| 18 |
+
|
| 19 |
+
if provider == LLMEnums.OPENAI.value:
|
| 20 |
+
return OpenAIProvider(
|
| 21 |
+
api_key=self.config.OPENAI_API_KEY,
|
| 22 |
+
api_url=self.config.OPENAI_API_URL,
|
| 23 |
+
default_input_max_characters=self.config.INPUT_DAFAULT_MAX_CHARACTERS,
|
| 24 |
+
default_generation_max_output_tokens=self.config.GENERATION_DAFAULT_MAX_TOKENS,
|
| 25 |
+
default_generation_temperature=self.config.GENERATION_DAFAULT_TEMPERATURE,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
if provider == LLMEnums.OLLAMA.value:
|
| 29 |
+
return OllamaProvider(
|
| 30 |
+
url=self.config.OLLAMA_URL,
|
| 31 |
+
api_key=self.config.OLLAMA_API_KEY,
|
| 32 |
+
default_input_max_characters=self.config.INPUT_DAFAULT_MAX_CHARACTERS,
|
| 33 |
+
default_generation_max_output_tokens=self.config.GENERATION_DAFAULT_MAX_TOKENS,
|
| 34 |
+
default_generation_temperature=self.config.GENERATION_DAFAULT_TEMPERATURE,
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
if provider == LLMEnums.COHERE.value:
|
| 38 |
+
return CohereProvider(
|
| 39 |
+
api_key=self.config.COHERE_API_KEY,
|
| 40 |
+
default_input_max_characters=self.config.INPUT_DAFAULT_MAX_CHARACTERS,
|
| 41 |
+
default_generation_max_output_tokens=self.config.GENERATION_DAFAULT_MAX_TOKENS,
|
| 42 |
+
default_generation_temperature=self.config.GENERATION_DAFAULT_TEMPERATURE,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
if provider == LLMEnums.MISTRAL.value:
|
| 46 |
+
return MistralProvider(
|
| 47 |
+
api_key=self.config.MISTRAL_API_KEY,
|
| 48 |
+
default_input_max_characters=self.config.INPUT_DAFAULT_MAX_CHARACTERS,
|
| 49 |
+
default_generation_max_output_tokens=self.config.GENERATION_DAFAULT_MAX_TOKENS,
|
| 50 |
+
default_generation_temperature=self.config.GENERATION_DAFAULT_TEMPERATURE,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
if provider == LLMEnums.GROQ.value:
|
| 54 |
+
return GroqProvider(
|
| 55 |
+
api_key=self.config.GROQ_API_KEY,
|
| 56 |
+
default_input_max_characters=self.config.INPUT_DAFAULT_MAX_CHARACTERS,
|
| 57 |
+
default_generation_max_output_tokens=self.config.GENERATION_DAFAULT_MAX_TOKENS,
|
| 58 |
+
default_generation_temperature=self.config.GENERATION_DAFAULT_TEMPERATURE,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
if provider == LLMEnums.OPENROUTER.value:
|
| 62 |
+
return OpenRouterProvider(
|
| 63 |
+
api_key=self.config.OPENROUTER_API_KEY,
|
| 64 |
+
default_input_max_characters=self.config.INPUT_DAFAULT_MAX_CHARACTERS,
|
| 65 |
+
default_generation_max_output_tokens=self.config.GENERATION_DAFAULT_MAX_TOKENS,
|
| 66 |
+
default_generation_temperature=self.config.GENERATION_DAFAULT_TEMPERATURE,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
if provider == LLMEnums.HUGGINGFACE.value:
|
| 70 |
+
return HuggingFaceProvider(
|
| 71 |
+
api_key=self.config.HF_API_KEY,
|
| 72 |
+
default_input_max_characters=self.config.INPUT_DAFAULT_MAX_CHARACTERS,
|
| 73 |
+
default_generation_max_output_tokens=self.config.GENERATION_DAFAULT_MAX_TOKENS,
|
| 74 |
+
default_generation_temperature=self.config.GENERATION_DAFAULT_TEMPERATURE,
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
if provider == LLMEnums.DEEPSEEK.value:
|
| 78 |
+
return DeepSeekProvider(
|
| 79 |
+
api_key=self.config.DEEPSEEK_API_KEY,
|
| 80 |
+
default_input_max_characters=self.config.INPUT_DAFAULT_MAX_CHARACTERS,
|
| 81 |
+
default_generation_max_output_tokens=self.config.GENERATION_DAFAULT_MAX_TOKENS,
|
| 82 |
+
default_generation_temperature=self.config.GENERATION_DAFAULT_TEMPERATURE,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
if provider == LLMEnums.GEMINI.value:
|
| 86 |
+
return GeminiProvider(
|
| 87 |
+
api_key=self.config.GEMINI_API_KEY,
|
| 88 |
+
default_input_max_characters=self.config.INPUT_DAFAULT_MAX_CHARACTERS,
|
| 89 |
+
default_generation_max_output_tokens=self.config.GENERATION_DAFAULT_MAX_TOKENS,
|
| 90 |
+
default_generation_temperature=self.config.GENERATION_DAFAULT_TEMPERATURE,
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
return None
|
stores/llm/__init__.py
ADDED
|
File without changes
|
stores/llm/providers/CohereProvider.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from stores.llm.LLMInterface import LLMInterface
|
| 2 |
+
import logging
|
| 3 |
+
import requests
|
| 4 |
+
import re
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
import math
|
| 8 |
+
class CohereProvider(LLMInterface):
|
| 9 |
+
def __init__(self, url: str = None, model: str = None,
|
| 10 |
+
default_input_max_characters: int = 1000,
|
| 11 |
+
default_generation_max_output_tokens: int = 1000,
|
| 12 |
+
default_generation_temperature: float = 0.1, api_key: str = None):
|
| 13 |
+
self.url = url or "https://api.cohere.com/v2"
|
| 14 |
+
self.api_key = api_key or os.getenv("COHERE_API_KEY")
|
| 15 |
+
self.model = model
|
| 16 |
+
self.generation_model_id = None
|
| 17 |
+
|
| 18 |
+
self.embedding_model = None
|
| 19 |
+
self.embedding_model_id = None
|
| 20 |
+
self.embedding_size = None
|
| 21 |
+
|
| 22 |
+
self.default_input_max_characters = default_input_max_characters
|
| 23 |
+
self.default_generation_max_output_tokens = default_generation_max_output_tokens
|
| 24 |
+
self.default_generation_temperature = default_generation_temperature
|
| 25 |
+
self.logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
def set_generation_model(self, model_id: str):
|
| 28 |
+
if model_id:
|
| 29 |
+
self.model = model_id
|
| 30 |
+
|
| 31 |
+
def set_embedding_model(self, model_id: str, embedding_size: int):
|
| 32 |
+
if model_id:
|
| 33 |
+
self.embedding_model = model_id
|
| 34 |
+
self.embedding_size = embedding_size
|
| 35 |
+
self.embedding_model_id = model_id
|
| 36 |
+
|
| 37 |
+
def process_text(self, text: str):
|
| 38 |
+
if not text:
|
| 39 |
+
return ""
|
| 40 |
+
return str(text).strip()
|
| 41 |
+
|
| 42 |
+
def generate_text(self, prompt: str, chat_history: list = None,
|
| 43 |
+
max_output_tokens: int = None, temperature: float = None):
|
| 44 |
+
try:
|
| 45 |
+
chat_history = chat_history or [] # safe handling
|
| 46 |
+
clean_prompt = self.process_text(prompt)
|
| 47 |
+
|
| 48 |
+
# Build messages list from chat_history + current prompt
|
| 49 |
+
messages = []
|
| 50 |
+
for entry in chat_history:
|
| 51 |
+
messages.append({
|
| 52 |
+
"role": entry.get("role", "user"),
|
| 53 |
+
"content": entry.get("content", "")
|
| 54 |
+
})
|
| 55 |
+
messages.append({"role": "user", "content": clean_prompt})
|
| 56 |
+
|
| 57 |
+
payload = {
|
| 58 |
+
"model": self.model,
|
| 59 |
+
"messages": messages,
|
| 60 |
+
"max_tokens": int(max_output_tokens or self.default_generation_max_output_tokens),
|
| 61 |
+
"temperature": float(temperature or self.default_generation_temperature),
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
url = self.url.rstrip("/") + "/chat"
|
| 65 |
+
headers = {
|
| 66 |
+
"Authorization": f"Bearer {self.api_key}",
|
| 67 |
+
"Content-Type": "application/json",
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=6000)
|
| 71 |
+
if resp.status_code != 200:
|
| 72 |
+
self.logger.error("Cohere generate failed: %s %s", resp.status_code, resp.text)
|
| 73 |
+
return None
|
| 74 |
+
|
| 75 |
+
data = resp.json()
|
| 76 |
+
|
| 77 |
+
# Extract generated text from Cohere v2 chat response
|
| 78 |
+
generated_text = ""
|
| 79 |
+
try:
|
| 80 |
+
generated_text = data["message"]["content"][0]["text"].strip()
|
| 81 |
+
except (KeyError, IndexError, TypeError):
|
| 82 |
+
self.logger.error("Unexpected Cohere response structure: %s", data)
|
| 83 |
+
return None
|
| 84 |
+
|
| 85 |
+
if not generated_text:
|
| 86 |
+
return None
|
| 87 |
+
|
| 88 |
+
# Mirror the same return shape as OllamaProvider
|
| 89 |
+
usage = data.get("usage", {})
|
| 90 |
+
return {
|
| 91 |
+
"model": data.get("model"),
|
| 92 |
+
"response": generated_text,
|
| 93 |
+
"tokens_generated": usage.get("tokens", {}).get("output_tokens"),
|
| 94 |
+
"total_duration_ms": None, # Cohere does not expose latency in response
|
| 95 |
+
"prompt_eval_tokens": usage.get("tokens", {}).get("input_tokens"),
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
except Exception as e:
|
| 99 |
+
self.logger.exception("Error in CohereProvider.generate_text: %s", e)
|
| 100 |
+
return None
|
| 101 |
+
|
| 102 |
+
def embed_text(self, text: str, document_type: str = None):
|
| 103 |
+
"""Return an embedding vector from Cohere."""
|
| 104 |
+
try:
|
| 105 |
+
if not self.embedding_model:
|
| 106 |
+
self.logger.error("Embedding model is not set before calling embed_text()")
|
| 107 |
+
return None
|
| 108 |
+
|
| 109 |
+
clean_text = self.process_text(text)
|
| 110 |
+
print(f"[DEBUG] Cleaned text: '{clean_text[:20]}...'")
|
| 111 |
+
if not clean_text:
|
| 112 |
+
return []
|
| 113 |
+
|
| 114 |
+
# Cohere requires an input_type; map document_type or fall back to "search_document"
|
| 115 |
+
input_type = document_type if document_type in (
|
| 116 |
+
"search_document", "search_query", "classification", "clustering"
|
| 117 |
+
) else "search_document"
|
| 118 |
+
|
| 119 |
+
payload = {
|
| 120 |
+
"model": self.embedding_model,
|
| 121 |
+
"texts": [clean_text],
|
| 122 |
+
"input_type": input_type,
|
| 123 |
+
"embedding_types": ["float"],
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
url = self.url.rstrip("/") + "/embed"
|
| 127 |
+
headers = {
|
| 128 |
+
"Authorization": f"Bearer {self.api_key}",
|
| 129 |
+
"Content-Type": "application/json",
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=200)
|
| 133 |
+
if resp.status_code != 200:
|
| 134 |
+
print(f"[ERROR] Cohere embedding failed: {resp.status_code} {resp.text}")
|
| 135 |
+
return None
|
| 136 |
+
|
| 137 |
+
data = resp.json()
|
| 138 |
+
|
| 139 |
+
# Cohere v2 returns embeddings under data.embeddings.float
|
| 140 |
+
embedding = None
|
| 141 |
+
try:
|
| 142 |
+
embedding = data["embeddings"]["float"][0]
|
| 143 |
+
except (KeyError, IndexError, TypeError):
|
| 144 |
+
pass
|
| 145 |
+
|
| 146 |
+
# Fallback: older v1-style shape
|
| 147 |
+
if embedding is None:
|
| 148 |
+
try:
|
| 149 |
+
embedding = data["embeddings"][0]
|
| 150 |
+
except (KeyError, IndexError, TypeError):
|
| 151 |
+
pass
|
| 152 |
+
|
| 153 |
+
if embedding is not None:
|
| 154 |
+
print(f"[DEBUG] Embedding length: {len(embedding)}")
|
| 155 |
+
return embedding
|
| 156 |
+
|
| 157 |
+
print("[WARNING] 'embedding' key not found in response JSON")
|
| 158 |
+
return None
|
| 159 |
+
|
| 160 |
+
except Exception as e:
|
| 161 |
+
print(f"[EXCEPTION] Error in CohereProvider.embed_text: {e}")
|
| 162 |
+
return None
|
| 163 |
+
|
| 164 |
+
def construct_prompt(self, prompt: str, role: str):
|
| 165 |
+
return {
|
| 166 |
+
"role": role,
|
| 167 |
+
"content": self.process_text(prompt)
|
| 168 |
+
}
|
| 169 |
+
def embed_text_batch(self, texts: list[str], batch_size: int = 96):
|
| 170 |
+
self.logger.info(f"Embedding {len(texts)} texts using batch_size={batch_size}")
|
| 171 |
+
|
| 172 |
+
if not self.embedding_model:
|
| 173 |
+
self.logger.error("Embedding model not set")
|
| 174 |
+
return None
|
| 175 |
+
|
| 176 |
+
all_embeddings = []
|
| 177 |
+
total_batches = math.ceil(len(texts) / batch_size)
|
| 178 |
+
|
| 179 |
+
url = self.url.rstrip("/") + "/embed"
|
| 180 |
+
headers = {
|
| 181 |
+
"Authorization": f"Bearer {self.api_key}",
|
| 182 |
+
"Content-Type": "application/json",
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
# Cohere free tier: 10 req/min | paid: 100 req/min
|
| 186 |
+
# Adjust MIN_SECONDS_PER_REQUEST to match your plan
|
| 187 |
+
MIN_SECONDS_PER_REQUEST = 0.65 # ~92 req/min (safe under 100/min paid)
|
| 188 |
+
MAX_RETRIES = 5
|
| 189 |
+
BACKOFF_BASE = 10 # seconds — doubles on each retry
|
| 190 |
+
|
| 191 |
+
for batch_idx, i in enumerate(range(0, len(texts), batch_size), start=1):
|
| 192 |
+
time.sleep(6)
|
| 193 |
+
batch = texts[i:i + batch_size]
|
| 194 |
+
clean_batch = [self.process_text(t) for t in batch if t]
|
| 195 |
+
|
| 196 |
+
# ── Progress ────────────────────────────────────────────────────────
|
| 197 |
+
done_texts = min(i + batch_size, len(texts))
|
| 198 |
+
pct = (batch_idx / total_batches) * 100
|
| 199 |
+
bar_filled = int(pct / 5) # 20-char bar
|
| 200 |
+
bar = "█" * bar_filled + "░" * (20 - bar_filled)
|
| 201 |
+
print(
|
| 202 |
+
f"\r[EMBED] [{bar}] {pct:5.1f}% "
|
| 203 |
+
f"batch {batch_idx}/{total_batches} "
|
| 204 |
+
f"({done_texts}/{len(texts)} texts)",
|
| 205 |
+
end="", flush=True
|
| 206 |
+
)
|
| 207 |
+
# ────────────────────────────────────────────────────────────────────
|
| 208 |
+
|
| 209 |
+
payload = {
|
| 210 |
+
"model": self.embedding_model,
|
| 211 |
+
"texts": clean_batch,
|
| 212 |
+
"input_type": "search_document",
|
| 213 |
+
"embedding_types": ["float"],
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
# ── Rate-limited request with exponential back-off ──────────────────
|
| 217 |
+
embeddings = None
|
| 218 |
+
request_start = time.monotonic()
|
| 219 |
+
|
| 220 |
+
for attempt in range(1, MAX_RETRIES + 1):
|
| 221 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=200)
|
| 222 |
+
|
| 223 |
+
if resp.status_code == 200:
|
| 224 |
+
break
|
| 225 |
+
|
| 226 |
+
if resp.status_code == 429:
|
| 227 |
+
retry_after = float(resp.headers.get("Retry-After", BACKOFF_BASE ** attempt))
|
| 228 |
+
print(
|
| 229 |
+
f"\n[RATE LIMIT] batch {batch_idx} — "
|
| 230 |
+
f"attempt {attempt}/{MAX_RETRIES}, "
|
| 231 |
+
f"waiting {retry_after:.1f}s …"
|
| 232 |
+
)
|
| 233 |
+
time.sleep(retry_after)
|
| 234 |
+
continue
|
| 235 |
+
|
| 236 |
+
# Any other non-200 — log and abort
|
| 237 |
+
self.logger.error(
|
| 238 |
+
"Cohere embedding failed (batch %d, attempt %d): %s %s",
|
| 239 |
+
batch_idx, attempt, resp.status_code, resp.text
|
| 240 |
+
)
|
| 241 |
+
return None
|
| 242 |
+
|
| 243 |
+
else:
|
| 244 |
+
# Exhausted all retries on 429
|
| 245 |
+
self.logger.error(
|
| 246 |
+
"Cohere embedding: max retries (%d) exceeded on batch %d",
|
| 247 |
+
MAX_RETRIES, batch_idx
|
| 248 |
+
)
|
| 249 |
+
return None
|
| 250 |
+
|
| 251 |
+
# ── Parse response ──────────────────────────────────────────────────
|
| 252 |
+
data = resp.json()
|
| 253 |
+
|
| 254 |
+
try:
|
| 255 |
+
embeddings = data["embeddings"]["float"] # v2 shape
|
| 256 |
+
except (KeyError, TypeError):
|
| 257 |
+
embeddings = data.get("embeddings") # v1 shape
|
| 258 |
+
|
| 259 |
+
if not embeddings:
|
| 260 |
+
self.logger.error("No embeddings returned from Cohere (batch %d)", batch_idx)
|
| 261 |
+
return None
|
| 262 |
+
|
| 263 |
+
self.logger.debug(f"Received {len(embeddings)} embeddings for batch {batch_idx}")
|
| 264 |
+
all_embeddings.extend(embeddings)
|
| 265 |
+
|
| 266 |
+
# ── Pace requests to stay under rate limit ──────────────────────────
|
| 267 |
+
elapsed = time.monotonic() - request_start
|
| 268 |
+
sleep_for = max(0.0, MIN_SECONDS_PER_REQUEST - elapsed)
|
| 269 |
+
if sleep_for > 0:
|
| 270 |
+
time.sleep(sleep_for)
|
| 271 |
+
# ────────────────────────────────────────────────────────────────────
|
| 272 |
+
|
| 273 |
+
# Final newline after the progress bar
|
| 274 |
+
print(f"\r[EMBED] [{'█' * 20}] 100.0% "
|
| 275 |
+
f"batch {total_batches}/{total_batches} "
|
| 276 |
+
f"({len(texts)}/{len(texts)} texts) ✓")
|
| 277 |
+
|
| 278 |
+
self.logger.info(f"Total embeddings created: {len(all_embeddings)}")
|
| 279 |
+
return all_embeddings
|
| 280 |
+
# def embed_text_batch(self, texts: list[str], batch_size: int = 32):
|
| 281 |
+
# self.logger.info(f"Embedding {len(texts)} texts using batch_size={batch_size}")
|
| 282 |
+
|
| 283 |
+
# if not self.embedding_model:
|
| 284 |
+
# self.logger.error("Embedding model not set")
|
| 285 |
+
# return None
|
| 286 |
+
|
| 287 |
+
# all_embeddings = []
|
| 288 |
+
|
| 289 |
+
# url = self.url.rstrip("/") + "/embed"
|
| 290 |
+
# headers = {
|
| 291 |
+
# "Authorization": f"Bearer {self.api_key}",
|
| 292 |
+
# "Content-Type": "application/json",
|
| 293 |
+
# }
|
| 294 |
+
|
| 295 |
+
# for i in range(0, len(texts), batch_size):
|
| 296 |
+
# batch = texts[i:i + batch_size]
|
| 297 |
+
# clean_batch = [self.process_text(t) for t in batch if t]
|
| 298 |
+
|
| 299 |
+
# print(f"[EMBED] Embedding {len(texts)} texts using batch_size={batch_size}")
|
| 300 |
+
|
| 301 |
+
# payload = {
|
| 302 |
+
# "model": self.embedding_model,
|
| 303 |
+
# "texts": clean_batch,
|
| 304 |
+
# "input_type": "search_document",
|
| 305 |
+
# "embedding_types": ["float"],
|
| 306 |
+
# }
|
| 307 |
+
|
| 308 |
+
# resp = requests.post(url, json=payload, headers=headers, timeout=200)
|
| 309 |
+
# if resp.status_code != 200:
|
| 310 |
+
# self.logger.error("Cohere embedding failed: %s %s", resp.status_code, resp.text)
|
| 311 |
+
# return None
|
| 312 |
+
|
| 313 |
+
# data = resp.json()
|
| 314 |
+
|
| 315 |
+
# # Handle both v2 (embeddings.float) and v1 (embeddings) shapes
|
| 316 |
+
# embeddings = None
|
| 317 |
+
# try:
|
| 318 |
+
# embeddings = data["embeddings"]["float"]
|
| 319 |
+
# except (KeyError, TypeError):
|
| 320 |
+
# embeddings = data.get("embeddings")
|
| 321 |
+
|
| 322 |
+
# if not embeddings:
|
| 323 |
+
# self.logger.error("No embeddings returned from Cohere")
|
| 324 |
+
# return None
|
| 325 |
+
|
| 326 |
+
# self.logger.debug(f"Received {len(embeddings)} embeddings")
|
| 327 |
+
# all_embeddings.extend(embeddings)
|
| 328 |
+
|
| 329 |
+
# self.logger.info(f"Total embeddings created: {len(all_embeddings)}")
|
| 330 |
+
# return all_embeddings
|
| 331 |
+
|
| 332 |
+
def clean_content(self, text: str) -> str:
|
| 333 |
+
text = re.sub(r'\[.*?\]\(.*?\)', '', text)
|
| 334 |
+
text = re.sub(r'\[[^\]]*\]', '', text)
|
| 335 |
+
text = re.sub(r'\n+', '\n', text).strip()
|
| 336 |
+
return text
|
| 337 |
+
|
| 338 |
+
def web_search(self, query: str):
|
| 339 |
+
"""Use Cohere's chat endpoint with web-search connector to perform a search."""
|
| 340 |
+
try:
|
| 341 |
+
payload = {
|
| 342 |
+
"model": self.model,
|
| 343 |
+
"messages": [{"role": "user", "content": query}],
|
| 344 |
+
"tools": [{"type": "web_search"}],
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
url = self.url.rstrip("/") + "/chat"
|
| 348 |
+
headers = {
|
| 349 |
+
"Authorization": f"Bearer {self.api_key}",
|
| 350 |
+
"Content-Type": "application/json",
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=6000)
|
| 354 |
+
|
| 355 |
+
if not resp or resp.status_code != 200:
|
| 356 |
+
return {
|
| 357 |
+
"text": "No relevant external results found.",
|
| 358 |
+
"references": []
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
data = resp.json()
|
| 362 |
+
|
| 363 |
+
combined_text = []
|
| 364 |
+
references = set()
|
| 365 |
+
|
| 366 |
+
# Extract assistant text
|
| 367 |
+
try:
|
| 368 |
+
assistant_text = data["message"]["content"][0]["text"]
|
| 369 |
+
combined_text.append(self.clean_content(assistant_text))
|
| 370 |
+
except (KeyError, IndexError, TypeError):
|
| 371 |
+
pass
|
| 372 |
+
|
| 373 |
+
# Extract citations / source URLs from Cohere's citations block
|
| 374 |
+
for citation in data.get("message", {}).get("citations", []):
|
| 375 |
+
for source in citation.get("sources", []):
|
| 376 |
+
url_val = source.get("url") or source.get("id", "")
|
| 377 |
+
if url_val.startswith("http"):
|
| 378 |
+
references.add(url_val)
|
| 379 |
+
|
| 380 |
+
# Also scan raw text for bare URLs (mirrors Ollama behaviour)
|
| 381 |
+
raw_text = "\n".join(combined_text)
|
| 382 |
+
for found_url in re.findall(r"https?://[^\s)]+", raw_text):
|
| 383 |
+
references.add(found_url)
|
| 384 |
+
|
| 385 |
+
return {
|
| 386 |
+
"text": "\n\n".join(combined_text[:3]),
|
| 387 |
+
"references": list(references)
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
except Exception as e:
|
| 391 |
+
self.logger.error("Cohere web search failed: %s", e)
|
| 392 |
+
return {
|
| 393 |
+
"text": f"Cohere search error: {str(e)}",
|
| 394 |
+
"references": []
|
| 395 |
+
}
|
stores/llm/providers/DeepSeekProvider.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from stores.llm.LLMInterface import LLMInterface
|
| 2 |
+
import logging
|
| 3 |
+
import requests
|
| 4 |
+
import re
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DeepSeekProvider(LLMInterface):
|
| 9 |
+
def __init__(self, url: str = None, model: str = None,
|
| 10 |
+
default_input_max_characters: int = 1000,
|
| 11 |
+
default_generation_max_output_tokens: int = 1000,
|
| 12 |
+
default_generation_temperature: float = 0.1, api_key: str = None):
|
| 13 |
+
self.url = url or "https://api.deepseek.com/v1"
|
| 14 |
+
self.api_key = api_key or os.getenv("DEEPSEEK_API_KEY")
|
| 15 |
+
self.model = model
|
| 16 |
+
self.generation_model_id = None
|
| 17 |
+
|
| 18 |
+
self.embedding_model = None
|
| 19 |
+
self.embedding_model_id = None
|
| 20 |
+
self.embedding_size = None
|
| 21 |
+
|
| 22 |
+
self.default_input_max_characters = default_input_max_characters
|
| 23 |
+
self.default_generation_max_output_tokens = default_generation_max_output_tokens
|
| 24 |
+
self.default_generation_temperature = default_generation_temperature
|
| 25 |
+
self.logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
def set_generation_model(self, model_id: str):
|
| 28 |
+
if model_id:
|
| 29 |
+
self.model = model_id
|
| 30 |
+
|
| 31 |
+
def set_embedding_model(self, model_id: str, embedding_size: int):
|
| 32 |
+
if model_id:
|
| 33 |
+
self.embedding_model = model_id
|
| 34 |
+
self.embedding_size = embedding_size
|
| 35 |
+
self.embedding_model_id = model_id
|
| 36 |
+
|
| 37 |
+
def process_text(self, text: str):
|
| 38 |
+
if not text:
|
| 39 |
+
return ""
|
| 40 |
+
return str(text).strip()
|
| 41 |
+
|
| 42 |
+
def generate_text(self, prompt: str, chat_history: list = None,
|
| 43 |
+
max_output_tokens: int = None, temperature: float = None):
|
| 44 |
+
try:
|
| 45 |
+
chat_history = chat_history or []
|
| 46 |
+
clean_prompt = self.process_text(prompt)
|
| 47 |
+
|
| 48 |
+
messages = []
|
| 49 |
+
for entry in chat_history:
|
| 50 |
+
messages.append({
|
| 51 |
+
"role": entry.get("role", "user"),
|
| 52 |
+
"content": entry.get("content", "")
|
| 53 |
+
})
|
| 54 |
+
messages.append({"role": "user", "content": clean_prompt})
|
| 55 |
+
|
| 56 |
+
payload = {
|
| 57 |
+
"model": self.model,
|
| 58 |
+
"messages": messages,
|
| 59 |
+
"max_tokens": int(max_output_tokens or self.default_generation_max_output_tokens),
|
| 60 |
+
"temperature": float(temperature or self.default_generation_temperature),
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
url = self.url.rstrip("/") + "/chat/completions"
|
| 64 |
+
headers = {
|
| 65 |
+
"Authorization": f"Bearer {self.api_key}",
|
| 66 |
+
"Content-Type": "application/json",
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=6000)
|
| 70 |
+
if resp.status_code != 200:
|
| 71 |
+
self.logger.error("DeepSeek generate failed: %s %s", resp.status_code, resp.text)
|
| 72 |
+
return None
|
| 73 |
+
|
| 74 |
+
data = resp.json()
|
| 75 |
+
|
| 76 |
+
try:
|
| 77 |
+
generated_text = data["choices"][0]["message"]["content"].strip()
|
| 78 |
+
except (KeyError, IndexError, TypeError):
|
| 79 |
+
self.logger.error("Unexpected DeepSeek response structure: %s", data)
|
| 80 |
+
return None
|
| 81 |
+
|
| 82 |
+
if not generated_text:
|
| 83 |
+
return None
|
| 84 |
+
|
| 85 |
+
usage = data.get("usage", {})
|
| 86 |
+
return {
|
| 87 |
+
"model": data.get("model"),
|
| 88 |
+
"response": generated_text,
|
| 89 |
+
"tokens_generated": usage.get("completion_tokens"),
|
| 90 |
+
"total_duration_ms": None,
|
| 91 |
+
"prompt_eval_tokens": usage.get("prompt_tokens"),
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
except Exception as e:
|
| 95 |
+
self.logger.exception("Error in DeepSeekProvider.generate_text: %s", e)
|
| 96 |
+
return None
|
| 97 |
+
|
| 98 |
+
def embed_text(self, text: str, document_type: str = None):
|
| 99 |
+
"""DeepSeek does not currently offer an embeddings endpoint — returns None."""
|
| 100 |
+
self.logger.warning("DeepSeekProvider does not support embeddings.")
|
| 101 |
+
return None
|
| 102 |
+
|
| 103 |
+
def construct_prompt(self, prompt: str, role: str):
|
| 104 |
+
return {
|
| 105 |
+
"role": role,
|
| 106 |
+
"content": self.process_text(prompt)
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
def embed_text_batch(self, texts: list[str], batch_size: int = 32):
|
| 110 |
+
"""DeepSeek does not currently offer an embeddings endpoint — returns None."""
|
| 111 |
+
self.logger.warning("DeepSeekProvider does not support embeddings.")
|
| 112 |
+
return None
|
| 113 |
+
|
| 114 |
+
def clean_content(self, text: str) -> str:
|
| 115 |
+
text = re.sub(r'\[.*?\]\(.*?\)', '', text)
|
| 116 |
+
text = re.sub(r'\[[^\]]*\]', '', text)
|
| 117 |
+
text = re.sub(r'\n+', '\n', text).strip()
|
| 118 |
+
return text
|
| 119 |
+
|
| 120 |
+
def web_search(self, query: str):
|
| 121 |
+
"""DeepSeek has no native web search — returns a not-supported notice."""
|
| 122 |
+
self.logger.warning("DeepSeekProvider.web_search is not natively supported.")
|
| 123 |
+
return {
|
| 124 |
+
"text": "Web search is not natively supported by the DeepSeek API.",
|
| 125 |
+
"references": []
|
| 126 |
+
}
|
stores/llm/providers/GeminiProvider.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
|
| 3 |
+
from stores.llm.LLMInterface import LLMInterface
|
| 4 |
+
import logging
|
| 5 |
+
import requests
|
| 6 |
+
import re
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class GeminiProvider(LLMInterface):
|
| 11 |
+
def __init__(self, url: str = None, model: str = None,
|
| 12 |
+
default_input_max_characters: int = 1000,
|
| 13 |
+
default_generation_max_output_tokens: int = 1000,
|
| 14 |
+
default_generation_temperature: float = 0.1, api_key: str = None):
|
| 15 |
+
self.url = url or "https://generativelanguage.googleapis.com/v1beta"
|
| 16 |
+
self.api_key = api_key or os.getenv("GEMINI_API_KEY")
|
| 17 |
+
self.model = model
|
| 18 |
+
self.generation_model_id = None
|
| 19 |
+
|
| 20 |
+
self.embedding_model = None
|
| 21 |
+
self.embedding_model_id = None
|
| 22 |
+
self.embedding_size = None
|
| 23 |
+
|
| 24 |
+
self.default_input_max_characters = default_input_max_characters
|
| 25 |
+
self.default_generation_max_output_tokens = default_generation_max_output_tokens
|
| 26 |
+
self.default_generation_temperature = default_generation_temperature
|
| 27 |
+
self.logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
def set_generation_model(self, model_id: str):
|
| 30 |
+
if model_id:
|
| 31 |
+
self.model = model_id
|
| 32 |
+
|
| 33 |
+
def set_embedding_model(self, model_id: str, embedding_size: int):
|
| 34 |
+
if model_id:
|
| 35 |
+
self.embedding_model = model_id
|
| 36 |
+
self.embedding_size = embedding_size
|
| 37 |
+
self.embedding_model_id = model_id
|
| 38 |
+
|
| 39 |
+
def process_text(self, text: str):
|
| 40 |
+
if not text:
|
| 41 |
+
return ""
|
| 42 |
+
return str(text).strip()
|
| 43 |
+
|
| 44 |
+
def _build_contents(self, prompt: str, chat_history: list) -> list:
|
| 45 |
+
"""Convert chat_history + prompt into Gemini's contents format."""
|
| 46 |
+
contents = []
|
| 47 |
+
for entry in chat_history:
|
| 48 |
+
role = entry.get("role", "user")
|
| 49 |
+
# Gemini uses 'model' instead of 'assistant'
|
| 50 |
+
if role == "assistant":
|
| 51 |
+
role = "model"
|
| 52 |
+
contents.append({
|
| 53 |
+
"role": role,
|
| 54 |
+
"parts": [{"text": entry.get("content", "")}]
|
| 55 |
+
})
|
| 56 |
+
contents.append({
|
| 57 |
+
"role": "user",
|
| 58 |
+
"parts": [{"text": prompt}]
|
| 59 |
+
})
|
| 60 |
+
return contents
|
| 61 |
+
|
| 62 |
+
def generate_text(self, prompt: str, chat_history: list = None,
|
| 63 |
+
max_output_tokens: int = None, temperature: float = None):
|
| 64 |
+
try:
|
| 65 |
+
chat_history = chat_history or []
|
| 66 |
+
clean_prompt = self.process_text(prompt)
|
| 67 |
+
|
| 68 |
+
contents = self._build_contents(clean_prompt, chat_history)
|
| 69 |
+
|
| 70 |
+
payload = {
|
| 71 |
+
"contents": contents,
|
| 72 |
+
"generationConfig": {
|
| 73 |
+
"maxOutputTokens": int(max_output_tokens or self.default_generation_max_output_tokens),
|
| 74 |
+
"temperature": float(temperature or self.default_generation_temperature),
|
| 75 |
+
}
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
url = (
|
| 79 |
+
f"{self.url.rstrip('/')}/models/{self.model}"
|
| 80 |
+
f":generateContent?key={self.api_key}"
|
| 81 |
+
)
|
| 82 |
+
headers = {"Content-Type": "application/json"}
|
| 83 |
+
|
| 84 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=6000)
|
| 85 |
+
if resp.status_code != 200:
|
| 86 |
+
self.logger.error("Gemini generate failed: %s %s", resp.status_code, resp.text)
|
| 87 |
+
return None
|
| 88 |
+
|
| 89 |
+
data = resp.json()
|
| 90 |
+
|
| 91 |
+
try:
|
| 92 |
+
generated_text = (
|
| 93 |
+
data["candidates"][0]["content"]["parts"][0]["text"].strip()
|
| 94 |
+
)
|
| 95 |
+
except (KeyError, IndexError, TypeError):
|
| 96 |
+
self.logger.error("Unexpected Gemini response structure: %s", data)
|
| 97 |
+
return None
|
| 98 |
+
|
| 99 |
+
if not generated_text:
|
| 100 |
+
return None
|
| 101 |
+
|
| 102 |
+
usage = data.get("usageMetadata", {})
|
| 103 |
+
return {
|
| 104 |
+
"model": self.model,
|
| 105 |
+
"response": generated_text,
|
| 106 |
+
"tokens_generated": usage.get("candidatesTokenCount"),
|
| 107 |
+
"total_duration_ms": None,
|
| 108 |
+
"prompt_eval_tokens": usage.get("promptTokenCount"),
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
except Exception as e:
|
| 112 |
+
self.logger.exception("Error in GeminiProvider.generate_text: %s", e)
|
| 113 |
+
return None
|
| 114 |
+
|
| 115 |
+
def embed_text(self, text: str, document_type: str = None):
|
| 116 |
+
try:
|
| 117 |
+
if not self.embedding_model:
|
| 118 |
+
self.logger.error("Embedding model is not set before calling embed_text()")
|
| 119 |
+
return None
|
| 120 |
+
|
| 121 |
+
clean_text = self.process_text(text)
|
| 122 |
+
print(f"[DEBUG] Cleaned text: '{clean_text[:20]}...'")
|
| 123 |
+
if not clean_text:
|
| 124 |
+
return []
|
| 125 |
+
|
| 126 |
+
# Map document_type to Gemini task type
|
| 127 |
+
task_type_map = {
|
| 128 |
+
"search_document": "RETRIEVAL_DOCUMENT",
|
| 129 |
+
"search_query": "RETRIEVAL_QUERY",
|
| 130 |
+
"classification": "CLASSIFICATION",
|
| 131 |
+
"clustering": "CLUSTERING",
|
| 132 |
+
}
|
| 133 |
+
task_type = task_type_map.get(document_type, "RETRIEVAL_DOCUMENT")
|
| 134 |
+
|
| 135 |
+
payload = {
|
| 136 |
+
"model": f"models/{self.embedding_model}",
|
| 137 |
+
"content": {"parts": [{"text": clean_text}]},
|
| 138 |
+
"output_dimensionality": 768,
|
| 139 |
+
"taskType": task_type,
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
url = (
|
| 143 |
+
f"{self.url.rstrip('/')}/models/{self.embedding_model}"
|
| 144 |
+
f":embedContent?key={self.api_key}"
|
| 145 |
+
)
|
| 146 |
+
headers = {"Content-Type": "application/json"}
|
| 147 |
+
|
| 148 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=200)
|
| 149 |
+
if resp.status_code != 200:
|
| 150 |
+
print(f"[ERROR] Gemini embedding failed: {resp.status_code} {resp.text}")
|
| 151 |
+
return None
|
| 152 |
+
|
| 153 |
+
data = resp.json()
|
| 154 |
+
|
| 155 |
+
try:
|
| 156 |
+
embedding = data["embedding"]["values"]
|
| 157 |
+
print(f"[DEBUG] Embedding length: {len(embedding)}")
|
| 158 |
+
return embedding
|
| 159 |
+
except (KeyError, TypeError):
|
| 160 |
+
print("[WARNING] 'embedding' key not found in response JSON")
|
| 161 |
+
return None
|
| 162 |
+
|
| 163 |
+
except Exception as e:
|
| 164 |
+
print(f"[EXCEPTION] Error in GeminiProvider.embed_text: {e}")
|
| 165 |
+
return None
|
| 166 |
+
|
| 167 |
+
def construct_prompt(self, prompt: str, role: str):
|
| 168 |
+
return {
|
| 169 |
+
"role": role,
|
| 170 |
+
"content": self.process_text(prompt)
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
def embed_text_batch(self, texts: list[str], batch_size: int = 32):
|
| 174 |
+
self.logger.info(f"Embedding {len(texts)} texts using batch_size={batch_size}")
|
| 175 |
+
|
| 176 |
+
if not self.embedding_model:
|
| 177 |
+
self.logger.error("Embedding model not set")
|
| 178 |
+
return None
|
| 179 |
+
|
| 180 |
+
all_embeddings = []
|
| 181 |
+
|
| 182 |
+
url = (
|
| 183 |
+
f"{self.url.rstrip('/')}/models/{self.embedding_model}"
|
| 184 |
+
f":batchEmbedContents?key={self.api_key}"
|
| 185 |
+
)
|
| 186 |
+
headers = {"Content-Type": "application/json"}
|
| 187 |
+
|
| 188 |
+
for i in range(0, len(texts), batch_size):
|
| 189 |
+
time.sleep(5)
|
| 190 |
+
batch = texts[i:i + batch_size]
|
| 191 |
+
clean_batch = [self.process_text(t) for t in batch if t]
|
| 192 |
+
|
| 193 |
+
print(f"[EMBED] Embedding {len(texts)} texts using batch_size={batch_size}")
|
| 194 |
+
|
| 195 |
+
# Gemini batchEmbedContents takes a list of requests
|
| 196 |
+
requests_list = [
|
| 197 |
+
{
|
| 198 |
+
"model": f"models/{self.embedding_model}",
|
| 199 |
+
"content": {"parts": [{"text": t}]},
|
| 200 |
+
"taskType": "RETRIEVAL_DOCUMENT",
|
| 201 |
+
"output_dimensionality": 768, # ← add this
|
| 202 |
+
}
|
| 203 |
+
for t in clean_batch
|
| 204 |
+
]
|
| 205 |
+
payload = {"requests": requests_list}
|
| 206 |
+
|
| 207 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=200)
|
| 208 |
+
if resp.status_code != 200:
|
| 209 |
+
self.logger.error("Gemini embedding failed: %s %s", resp.status_code, resp.text)
|
| 210 |
+
return None
|
| 211 |
+
|
| 212 |
+
data = resp.json()
|
| 213 |
+
|
| 214 |
+
try:
|
| 215 |
+
embeddings = [item["values"] for item in data["embeddings"]]
|
| 216 |
+
except (KeyError, TypeError):
|
| 217 |
+
self.logger.error("No embeddings returned from Gemini")
|
| 218 |
+
return None
|
| 219 |
+
|
| 220 |
+
if not embeddings:
|
| 221 |
+
self.logger.error("No embeddings returned from Gemini")
|
| 222 |
+
return None
|
| 223 |
+
|
| 224 |
+
self.logger.debug(f"Received {len(embeddings)} embeddings")
|
| 225 |
+
all_embeddings.extend(embeddings)
|
| 226 |
+
|
| 227 |
+
self.logger.info(f"Total embeddings created: {len(all_embeddings)}")
|
| 228 |
+
return all_embeddings
|
| 229 |
+
|
| 230 |
+
def clean_content(self, text: str) -> str:
|
| 231 |
+
text = re.sub(r'\[.*?\]\(.*?\)', '', text)
|
| 232 |
+
text = re.sub(r'\[[^\]]*\]', '', text)
|
| 233 |
+
text = re.sub(r'\n+', '\n', text).strip()
|
| 234 |
+
return text
|
| 235 |
+
|
| 236 |
+
def web_search(self, query: str):
|
| 237 |
+
"""
|
| 238 |
+
Gemini supports Google Search grounding via the tools parameter.
|
| 239 |
+
Uses generateContent with the googleSearch tool enabled.
|
| 240 |
+
"""
|
| 241 |
+
try:
|
| 242 |
+
payload = {
|
| 243 |
+
"contents": [{"role": "user", "parts": [{"text": query}]}],
|
| 244 |
+
"tools": [{"google_search": {}}],
|
| 245 |
+
"generationConfig": {
|
| 246 |
+
"maxOutputTokens": int(self.default_generation_max_output_tokens),
|
| 247 |
+
"temperature": float(self.default_generation_temperature),
|
| 248 |
+
}
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
url = (
|
| 252 |
+
f"{self.url.rstrip('/')}/models/{self.model}"
|
| 253 |
+
f":generateContent?key={self.api_key}"
|
| 254 |
+
)
|
| 255 |
+
headers = {"Content-Type": "application/json"}
|
| 256 |
+
|
| 257 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=6000)
|
| 258 |
+
|
| 259 |
+
if not resp or resp.status_code != 200:
|
| 260 |
+
return {
|
| 261 |
+
"text": "No relevant external results found.",
|
| 262 |
+
"references": []
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
data = resp.json()
|
| 266 |
+
|
| 267 |
+
combined_text = []
|
| 268 |
+
references = set()
|
| 269 |
+
|
| 270 |
+
try:
|
| 271 |
+
text_content = data["candidates"][0]["content"]["parts"][0]["text"]
|
| 272 |
+
combined_text.append(self.clean_content(text_content))
|
| 273 |
+
except (KeyError, IndexError, TypeError):
|
| 274 |
+
pass
|
| 275 |
+
|
| 276 |
+
# Extract grounding metadata URLs
|
| 277 |
+
try:
|
| 278 |
+
chunks = (
|
| 279 |
+
data["candidates"][0]
|
| 280 |
+
.get("groundingMetadata", {})
|
| 281 |
+
.get("groundingChunks", [])
|
| 282 |
+
)
|
| 283 |
+
for chunk in chunks:
|
| 284 |
+
web = chunk.get("web", {})
|
| 285 |
+
uri = web.get("uri", "")
|
| 286 |
+
if uri.startswith("http"):
|
| 287 |
+
references.add(uri)
|
| 288 |
+
except (KeyError, IndexError, TypeError):
|
| 289 |
+
pass
|
| 290 |
+
|
| 291 |
+
# Also scan response text for bare URLs
|
| 292 |
+
for found_url in re.findall(r"https?://[^\s)]+", "\n".join(combined_text)):
|
| 293 |
+
references.add(found_url)
|
| 294 |
+
|
| 295 |
+
return {
|
| 296 |
+
"text": "\n\n".join(combined_text[:3]),
|
| 297 |
+
"references": list(references)
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
except Exception as e:
|
| 301 |
+
self.logger.error("Gemini web search failed: %s", e)
|
| 302 |
+
return {
|
| 303 |
+
"text": f"Gemini search error: {str(e)}",
|
| 304 |
+
"references": []
|
| 305 |
+
}
|
stores/llm/providers/GroqProvider.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from stores.llm.LLMInterface import LLMInterface
|
| 2 |
+
import logging
|
| 3 |
+
import requests
|
| 4 |
+
import re
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class GroqProvider(LLMInterface):
|
| 9 |
+
def __init__(self, url: str = None, model: str = None,
|
| 10 |
+
default_input_max_characters: int = 1000,
|
| 11 |
+
default_generation_max_output_tokens: int = 1000,
|
| 12 |
+
default_generation_temperature: float = 0.1, api_key: str = None):
|
| 13 |
+
self.url = url or "https://api.groq.com/openai/v1"
|
| 14 |
+
self.api_key = api_key or os.getenv("GROQ_API_KEY")
|
| 15 |
+
self.model = model
|
| 16 |
+
self.generation_model_id = None
|
| 17 |
+
|
| 18 |
+
self.embedding_model = None
|
| 19 |
+
self.embedding_model_id = None
|
| 20 |
+
self.embedding_size = None
|
| 21 |
+
|
| 22 |
+
self.default_input_max_characters = default_input_max_characters
|
| 23 |
+
self.default_generation_max_output_tokens = default_generation_max_output_tokens
|
| 24 |
+
self.default_generation_temperature = default_generation_temperature
|
| 25 |
+
self.logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
def set_generation_model(self, model_id: str):
|
| 28 |
+
if model_id:
|
| 29 |
+
self.model = model_id
|
| 30 |
+
|
| 31 |
+
def set_embedding_model(self, model_id: str, embedding_size: int):
|
| 32 |
+
if model_id:
|
| 33 |
+
self.embedding_model = model_id
|
| 34 |
+
self.embedding_size = embedding_size
|
| 35 |
+
self.embedding_model_id = model_id
|
| 36 |
+
|
| 37 |
+
def process_text(self, text: str):
|
| 38 |
+
if not text:
|
| 39 |
+
return ""
|
| 40 |
+
return str(text).strip()
|
| 41 |
+
|
| 42 |
+
def generate_text(self, prompt: str, chat_history: list = None,
|
| 43 |
+
max_output_tokens: int = None, temperature: float = None):
|
| 44 |
+
try:
|
| 45 |
+
chat_history = chat_history or []
|
| 46 |
+
clean_prompt = self.process_text(prompt)
|
| 47 |
+
|
| 48 |
+
messages = []
|
| 49 |
+
for entry in chat_history:
|
| 50 |
+
messages.append({
|
| 51 |
+
"role": entry.get("role", "user"),
|
| 52 |
+
"content": entry.get("content", "")
|
| 53 |
+
})
|
| 54 |
+
messages.append({"role": "user", "content": clean_prompt})
|
| 55 |
+
|
| 56 |
+
payload = {
|
| 57 |
+
"model": self.model,
|
| 58 |
+
"messages": messages,
|
| 59 |
+
"max_tokens": int(max_output_tokens or self.default_generation_max_output_tokens),
|
| 60 |
+
"temperature": float(temperature or self.default_generation_temperature),
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
url = self.url.rstrip("/") + "/chat/completions"
|
| 64 |
+
headers = {
|
| 65 |
+
"Authorization": f"Bearer {self.api_key}",
|
| 66 |
+
"Content-Type": "application/json",
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=6000)
|
| 70 |
+
if resp.status_code != 200:
|
| 71 |
+
self.logger.error("Groq generate failed: %s %s", resp.status_code, resp.text)
|
| 72 |
+
return None
|
| 73 |
+
|
| 74 |
+
data = resp.json()
|
| 75 |
+
|
| 76 |
+
try:
|
| 77 |
+
generated_text = data["choices"][0]["message"]["content"].strip()
|
| 78 |
+
except (KeyError, IndexError, TypeError):
|
| 79 |
+
self.logger.error("Unexpected Groq response structure: %s", data)
|
| 80 |
+
return None
|
| 81 |
+
|
| 82 |
+
if not generated_text:
|
| 83 |
+
return None
|
| 84 |
+
|
| 85 |
+
usage = data.get("usage", {})
|
| 86 |
+
# Groq exposes x_groq.usage.total_time in seconds
|
| 87 |
+
total_time_ms = None
|
| 88 |
+
try:
|
| 89 |
+
total_time_ms = round(data["x_groq"]["usage"]["total_time"] * 1000, 2)
|
| 90 |
+
except (KeyError, TypeError):
|
| 91 |
+
pass
|
| 92 |
+
|
| 93 |
+
return {
|
| 94 |
+
"model": data.get("model"),
|
| 95 |
+
"response": generated_text,
|
| 96 |
+
"tokens_generated": usage.get("completion_tokens"),
|
| 97 |
+
"total_duration_ms": total_time_ms,
|
| 98 |
+
"prompt_eval_tokens": usage.get("prompt_tokens"),
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
except Exception as e:
|
| 102 |
+
self.logger.exception("Error in GroqProvider.generate_text: %s", e)
|
| 103 |
+
return None
|
| 104 |
+
|
| 105 |
+
def embed_text(self, text: str, document_type: str = None):
|
| 106 |
+
"""Groq does not support embeddings — returns None."""
|
| 107 |
+
self.logger.warning("GroqProvider does not support embeddings.")
|
| 108 |
+
return None
|
| 109 |
+
|
| 110 |
+
def construct_prompt(self, prompt: str, role: str):
|
| 111 |
+
return {
|
| 112 |
+
"role": role,
|
| 113 |
+
"content": self.process_text(prompt)
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
def embed_text_batch(self, texts: list[str], batch_size: int = 32):
|
| 117 |
+
"""Groq does not support embeddings — returns None."""
|
| 118 |
+
self.logger.warning("GroqProvider does not support embeddings.")
|
| 119 |
+
return None
|
| 120 |
+
|
| 121 |
+
def clean_content(self, text: str) -> str:
|
| 122 |
+
text = re.sub(r'\[.*?\]\(.*?\)', '', text)
|
| 123 |
+
text = re.sub(r'\[[^\]]*\]', '', text)
|
| 124 |
+
text = re.sub(r'\n+', '\n', text).strip()
|
| 125 |
+
return text
|
| 126 |
+
|
| 127 |
+
def web_search(self, query: str):
|
| 128 |
+
"""Groq has no native web search — returns a not-supported notice."""
|
| 129 |
+
self.logger.warning("GroqProvider.web_search is not natively supported.")
|
| 130 |
+
return {
|
| 131 |
+
"text": "Web search is not natively supported by the Groq API.",
|
| 132 |
+
"references": []
|
| 133 |
+
}
|
stores/llm/providers/HuggingFaceProvider.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from stores.llm.LLMInterface import LLMInterface
|
| 2 |
+
import logging
|
| 3 |
+
import requests
|
| 4 |
+
import re
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class HuggingFaceProvider(LLMInterface):
|
| 9 |
+
def __init__(self, url: str = None, model: str = None,
|
| 10 |
+
default_input_max_characters: int = 1000,
|
| 11 |
+
default_generation_max_output_tokens: int = 1000,
|
| 12 |
+
default_generation_temperature: float = 0.1, api_key: str = None):
|
| 13 |
+
# Supports both Inference API (serverless) and Inference Endpoints (dedicated)
|
| 14 |
+
self.url = url or "https://router.huggingface.co"
|
| 15 |
+
self.api_key = api_key or os.getenv("HF_API_KEY")
|
| 16 |
+
self.model = model
|
| 17 |
+
self.generation_model_id = None
|
| 18 |
+
|
| 19 |
+
self.embedding_model = None
|
| 20 |
+
self.embedding_model_id = None
|
| 21 |
+
self.embedding_size = None
|
| 22 |
+
|
| 23 |
+
self.default_input_max_characters = default_input_max_characters
|
| 24 |
+
self.default_generation_max_output_tokens = default_generation_max_output_tokens
|
| 25 |
+
self.default_generation_temperature = default_generation_temperature
|
| 26 |
+
self.logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
def set_generation_model(self, model_id: str):
|
| 29 |
+
if model_id:
|
| 30 |
+
self.model = model_id
|
| 31 |
+
|
| 32 |
+
def set_embedding_model(self, model_id: str, embedding_size: int):
|
| 33 |
+
if model_id:
|
| 34 |
+
self.embedding_model = model_id
|
| 35 |
+
self.embedding_size = embedding_size
|
| 36 |
+
self.embedding_model_id = model_id
|
| 37 |
+
|
| 38 |
+
def process_text(self, text: str):
|
| 39 |
+
if not text:
|
| 40 |
+
return ""
|
| 41 |
+
return str(text).strip()
|
| 42 |
+
|
| 43 |
+
def generate_text(self, prompt: str, chat_history: list = None,
|
| 44 |
+
max_output_tokens: int = None, temperature: float = None):
|
| 45 |
+
try:
|
| 46 |
+
chat_history = chat_history or []
|
| 47 |
+
clean_prompt = self.process_text(prompt)
|
| 48 |
+
|
| 49 |
+
messages = []
|
| 50 |
+
for entry in chat_history:
|
| 51 |
+
messages.append({
|
| 52 |
+
"role": entry.get("role", "user"),
|
| 53 |
+
"content": entry.get("content", "")
|
| 54 |
+
})
|
| 55 |
+
messages.append({"role": "user", "content": clean_prompt})
|
| 56 |
+
|
| 57 |
+
payload = {
|
| 58 |
+
"model": self.model,
|
| 59 |
+
"messages": messages,
|
| 60 |
+
"max_tokens": int(max_output_tokens or self.default_generation_max_output_tokens),
|
| 61 |
+
"temperature": float(temperature or self.default_generation_temperature),
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
# HF Inference API (serverless): /v1/chat/completions (OpenAI-compatible)
|
| 65 |
+
url = self.url.rstrip("/") + "/v1/chat/completions"
|
| 66 |
+
headers = {
|
| 67 |
+
"Authorization": f"Bearer {self.api_key}",
|
| 68 |
+
"Content-Type": "application/json",
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=6000)
|
| 72 |
+
if resp.status_code != 200:
|
| 73 |
+
self.logger.error("HuggingFace generate failed: %s %s", resp.status_code, resp.text)
|
| 74 |
+
return None
|
| 75 |
+
|
| 76 |
+
data = resp.json()
|
| 77 |
+
|
| 78 |
+
try:
|
| 79 |
+
generated_text = data["choices"][0]["message"]["content"].strip()
|
| 80 |
+
except (KeyError, IndexError, TypeError):
|
| 81 |
+
self.logger.error("Unexpected HuggingFace response structure: %s", data)
|
| 82 |
+
return None
|
| 83 |
+
|
| 84 |
+
if not generated_text:
|
| 85 |
+
return None
|
| 86 |
+
|
| 87 |
+
usage = data.get("usage", {})
|
| 88 |
+
return {
|
| 89 |
+
"model": data.get("model"),
|
| 90 |
+
"response": generated_text,
|
| 91 |
+
"tokens_generated": usage.get("completion_tokens"),
|
| 92 |
+
"total_duration_ms": None,
|
| 93 |
+
"prompt_eval_tokens": usage.get("prompt_tokens"),
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
except Exception as e:
|
| 97 |
+
self.logger.exception("Error in HuggingFaceProvider.generate_text: %s", e)
|
| 98 |
+
return None
|
| 99 |
+
|
| 100 |
+
def embed_text(self, text: str, document_type: str = None):
|
| 101 |
+
try:
|
| 102 |
+
if not self.embedding_model:
|
| 103 |
+
self.logger.error("Embedding model is not set before calling embed_text()")
|
| 104 |
+
return None
|
| 105 |
+
|
| 106 |
+
clean_text = self.process_text(text)
|
| 107 |
+
print(f"[DEBUG] Cleaned text: '{clean_text[:20]}...'")
|
| 108 |
+
if not clean_text:
|
| 109 |
+
return []
|
| 110 |
+
|
| 111 |
+
payload = {"inputs": clean_text}
|
| 112 |
+
|
| 113 |
+
# Feature-extraction endpoint per model
|
| 114 |
+
url = f"https://router.huggingface.co/hf-inference/models/{self.embedding_model}/pipeline/feature-extraction"
|
| 115 |
+
headers = {
|
| 116 |
+
"Authorization": f"Bearer {self.api_key}",
|
| 117 |
+
"Content-Type": "application/json",
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=200)
|
| 121 |
+
if resp.status_code != 200:
|
| 122 |
+
print(f"[ERROR] HuggingFace embedding failed: {resp.status_code} {resp.text}")
|
| 123 |
+
return None
|
| 124 |
+
|
| 125 |
+
data = resp.json()
|
| 126 |
+
|
| 127 |
+
# HF returns a nested list: [[vector]] for single input
|
| 128 |
+
embedding = None
|
| 129 |
+
if isinstance(data, list):
|
| 130 |
+
if len(data) > 0 and isinstance(data[0], list):
|
| 131 |
+
embedding = data[0] # [[float, ...]] -> [float, ...]
|
| 132 |
+
elif len(data) > 0 and isinstance(data[0], float):
|
| 133 |
+
embedding = data # [float, ...] already flat
|
| 134 |
+
elif isinstance(data, dict) and "embedding" in data:
|
| 135 |
+
embedding = data["embedding"]
|
| 136 |
+
|
| 137 |
+
if embedding is not None:
|
| 138 |
+
print(f"[DEBUG] Embedding length: {len(embedding)}")
|
| 139 |
+
return embedding
|
| 140 |
+
|
| 141 |
+
print("[WARNING] 'embedding' key not found in response JSON")
|
| 142 |
+
return None
|
| 143 |
+
|
| 144 |
+
except Exception as e:
|
| 145 |
+
print(f"[EXCEPTION] Error in HuggingFaceProvider.embed_text: {e}")
|
| 146 |
+
return None
|
| 147 |
+
|
| 148 |
+
def construct_prompt(self, prompt: str, role: str):
|
| 149 |
+
return {
|
| 150 |
+
"role": role,
|
| 151 |
+
"content": self.process_text(prompt)
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
def embed_text_batch(self, texts: list[str], batch_size: int = 32):
|
| 155 |
+
self.logger.info(f"Embedding {len(texts)} texts using batch_size={batch_size}")
|
| 156 |
+
|
| 157 |
+
if not self.embedding_model:
|
| 158 |
+
self.logger.error("Embedding model not set")
|
| 159 |
+
return None
|
| 160 |
+
|
| 161 |
+
all_embeddings = []
|
| 162 |
+
|
| 163 |
+
url = f"https://router.huggingface.co/hf-inference/models/{self.embedding_model}/pipeline/feature-extraction"
|
| 164 |
+
headers = {
|
| 165 |
+
"Authorization": f"Bearer {self.api_key}",
|
| 166 |
+
"Content-Type": "application/json",
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
for i in range(0, len(texts), batch_size):
|
| 170 |
+
batch = texts[i:i + batch_size]
|
| 171 |
+
clean_batch = [self.process_text(t) for t in batch if t]
|
| 172 |
+
|
| 173 |
+
print(f"[EMBED] Embedding {len(texts)} texts using batch_size={batch_size}")
|
| 174 |
+
|
| 175 |
+
payload = {"inputs": clean_batch}
|
| 176 |
+
|
| 177 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=200)
|
| 178 |
+
if resp.status_code != 200:
|
| 179 |
+
self.logger.error("HuggingFace embedding failed: %s %s", resp.status_code, resp.text)
|
| 180 |
+
return None
|
| 181 |
+
|
| 182 |
+
data = resp.json()
|
| 183 |
+
|
| 184 |
+
# Batch response: [[vec1], [vec2], ...] or [[f,f,...], [f,f,...]]
|
| 185 |
+
embeddings = None
|
| 186 |
+
if isinstance(data, list) and len(data) > 0:
|
| 187 |
+
if isinstance(data[0], list):
|
| 188 |
+
embeddings = data
|
| 189 |
+
elif isinstance(data[0], float):
|
| 190 |
+
embeddings = [data] # single vector returned flat
|
| 191 |
+
|
| 192 |
+
if not embeddings:
|
| 193 |
+
self.logger.error("No embeddings returned from HuggingFace")
|
| 194 |
+
return None
|
| 195 |
+
|
| 196 |
+
self.logger.debug(f"Received {len(embeddings)} embeddings")
|
| 197 |
+
all_embeddings.extend(embeddings)
|
| 198 |
+
|
| 199 |
+
self.logger.info(f"Total embeddings created: {len(all_embeddings)}")
|
| 200 |
+
return all_embeddings
|
| 201 |
+
|
| 202 |
+
def clean_content(self, text: str) -> str:
|
| 203 |
+
text = re.sub(r'\[.*?\]\(.*?\)', '', text)
|
| 204 |
+
text = re.sub(r'\[[^\]]*\]', '', text)
|
| 205 |
+
text = re.sub(r'\n+', '\n', text).strip()
|
| 206 |
+
return text
|
| 207 |
+
|
| 208 |
+
def web_search(self, query: str):
|
| 209 |
+
"""HuggingFace Inference API has no native web search — returns a not-supported notice."""
|
| 210 |
+
self.logger.warning("HuggingFaceProvider.web_search is not natively supported.")
|
| 211 |
+
return {
|
| 212 |
+
"text": "Web search is not natively supported by the HuggingFace Inference API.",
|
| 213 |
+
"references": []
|
| 214 |
+
}
|
stores/llm/providers/MistralProvider.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
|
| 3 |
+
from stores.llm.LLMInterface import LLMInterface
|
| 4 |
+
import logging
|
| 5 |
+
import requests
|
| 6 |
+
import re
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class MistralProvider(LLMInterface):
|
| 11 |
+
def __init__(self, url: str = None, model: str = None,
|
| 12 |
+
default_input_max_characters: int = 1000,
|
| 13 |
+
default_generation_max_output_tokens: int = 1000,
|
| 14 |
+
default_generation_temperature: float = 0.1, api_key: str = None):
|
| 15 |
+
self.url = url or "https://api.mistral.ai/v1"
|
| 16 |
+
self.api_key = api_key or os.getenv("MISTRAL_API_KEY")
|
| 17 |
+
self.model = model
|
| 18 |
+
self.generation_model_id = None
|
| 19 |
+
|
| 20 |
+
self.embedding_model = None
|
| 21 |
+
self.embedding_model_id = None
|
| 22 |
+
self.embedding_size = None
|
| 23 |
+
|
| 24 |
+
self.default_input_max_characters = default_input_max_characters
|
| 25 |
+
self.default_generation_max_output_tokens = default_generation_max_output_tokens
|
| 26 |
+
self.default_generation_temperature = default_generation_temperature
|
| 27 |
+
self.logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
def set_generation_model(self, model_id: str):
|
| 30 |
+
if model_id:
|
| 31 |
+
self.model = model_id
|
| 32 |
+
|
| 33 |
+
def set_embedding_model(self, model_id: str, embedding_size: int):
|
| 34 |
+
if model_id:
|
| 35 |
+
self.embedding_model = model_id
|
| 36 |
+
self.embedding_size = embedding_size
|
| 37 |
+
self.embedding_model_id = model_id
|
| 38 |
+
|
| 39 |
+
def process_text(self, text: str):
|
| 40 |
+
if not text:
|
| 41 |
+
return ""
|
| 42 |
+
return str(text).strip()
|
| 43 |
+
|
| 44 |
+
def generate_text(self, prompt: str, chat_history: list = None,
|
| 45 |
+
max_output_tokens: int = None, temperature: float = None):
|
| 46 |
+
try:
|
| 47 |
+
chat_history = chat_history or []
|
| 48 |
+
clean_prompt = self.process_text(prompt)
|
| 49 |
+
|
| 50 |
+
messages = []
|
| 51 |
+
for entry in chat_history:
|
| 52 |
+
messages.append({
|
| 53 |
+
"role": entry.get("role", "user"),
|
| 54 |
+
"content": entry.get("content", "")
|
| 55 |
+
})
|
| 56 |
+
messages.append({"role": "user", "content": clean_prompt})
|
| 57 |
+
|
| 58 |
+
payload = {
|
| 59 |
+
"model": self.model,
|
| 60 |
+
"messages": messages,
|
| 61 |
+
"max_tokens": int(max_output_tokens or self.default_generation_max_output_tokens),
|
| 62 |
+
"temperature": float(temperature or self.default_generation_temperature),
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
url = self.url.rstrip("/") + "/chat/completions"
|
| 66 |
+
headers = {
|
| 67 |
+
"Authorization": f"Bearer {self.api_key}",
|
| 68 |
+
"Content-Type": "application/json",
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=6000)
|
| 72 |
+
if resp.status_code != 200:
|
| 73 |
+
self.logger.error("Mistral generate failed: %s %s", resp.status_code, resp.text)
|
| 74 |
+
return None
|
| 75 |
+
|
| 76 |
+
data = resp.json()
|
| 77 |
+
|
| 78 |
+
try:
|
| 79 |
+
generated_text = data["choices"][0]["message"]["content"].strip()
|
| 80 |
+
except (KeyError, IndexError, TypeError):
|
| 81 |
+
self.logger.error("Unexpected Mistral response structure: %s", data)
|
| 82 |
+
return None
|
| 83 |
+
|
| 84 |
+
if not generated_text:
|
| 85 |
+
return None
|
| 86 |
+
|
| 87 |
+
usage = data.get("usage", {})
|
| 88 |
+
return {
|
| 89 |
+
"model": data.get("model"),
|
| 90 |
+
"response": generated_text,
|
| 91 |
+
"tokens_generated": usage.get("completion_tokens"),
|
| 92 |
+
"total_duration_ms": None,
|
| 93 |
+
"prompt_eval_tokens": usage.get("prompt_tokens"),
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
except Exception as e:
|
| 97 |
+
self.logger.exception("Error in MistralProvider.generate_text: %s", e)
|
| 98 |
+
return None
|
| 99 |
+
|
| 100 |
+
def embed_text(self, text: str, document_type: str = None):
|
| 101 |
+
try:
|
| 102 |
+
if not self.embedding_model:
|
| 103 |
+
self.logger.error("Embedding model is not set before calling embed_text()")
|
| 104 |
+
return None
|
| 105 |
+
|
| 106 |
+
clean_text = self.process_text(text)
|
| 107 |
+
print(f"[DEBUG] Cleaned text: '{clean_text[:20]}...'")
|
| 108 |
+
if not clean_text:
|
| 109 |
+
return []
|
| 110 |
+
|
| 111 |
+
payload = {
|
| 112 |
+
"model": self.embedding_model,
|
| 113 |
+
"input": [clean_text],
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
url = self.url.rstrip("/") + "/embeddings"
|
| 117 |
+
headers = {
|
| 118 |
+
"Authorization": f"Bearer {self.api_key}",
|
| 119 |
+
"Content-Type": "application/json",
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=200)
|
| 123 |
+
if resp.status_code != 200:
|
| 124 |
+
print(f"[ERROR] Mistral embedding failed: {resp.status_code} {resp.text}")
|
| 125 |
+
return None
|
| 126 |
+
|
| 127 |
+
data = resp.json()
|
| 128 |
+
|
| 129 |
+
try:
|
| 130 |
+
embedding = data["data"][0]["embedding"]
|
| 131 |
+
print(f"[DEBUG] Embedding length: {len(embedding)}")
|
| 132 |
+
return embedding
|
| 133 |
+
except (KeyError, IndexError, TypeError):
|
| 134 |
+
print("[WARNING] 'embedding' key not found in response JSON")
|
| 135 |
+
return None
|
| 136 |
+
|
| 137 |
+
except Exception as e:
|
| 138 |
+
print(f"[EXCEPTION] Error in MistralProvider.embed_text: {e}")
|
| 139 |
+
return None
|
| 140 |
+
|
| 141 |
+
def construct_prompt(self, prompt: str, role: str):
|
| 142 |
+
return {
|
| 143 |
+
"role": role,
|
| 144 |
+
"content": self.process_text(prompt)
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
def embed_text_batch(self, texts: list[str], batch_size: int = 32):
|
| 148 |
+
self.logger.info(f"Embedding {len(texts)} texts using batch_size={batch_size}")
|
| 149 |
+
|
| 150 |
+
if not self.embedding_model:
|
| 151 |
+
self.logger.error("Embedding model not set")
|
| 152 |
+
return None
|
| 153 |
+
|
| 154 |
+
all_embeddings = []
|
| 155 |
+
url = self.url.rstrip("/") + "/embeddings"
|
| 156 |
+
headers = {
|
| 157 |
+
"Authorization": f"Bearer {self.api_key}",
|
| 158 |
+
"Content-Type": "application/json",
|
| 159 |
+
}
|
| 160 |
+
|
| 161 |
+
for i in range(0, len(texts), batch_size):
|
| 162 |
+
time.sleep(5)
|
| 163 |
+
batch = texts[i:i + batch_size]
|
| 164 |
+
clean_batch = [self.process_text(t) for t in batch if t]
|
| 165 |
+
|
| 166 |
+
print(f"[EMBED] Embedding {len(texts)} texts using batch_size={batch_size}")
|
| 167 |
+
|
| 168 |
+
payload = {
|
| 169 |
+
"model": self.embedding_model,
|
| 170 |
+
"input": clean_batch,
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=200)
|
| 174 |
+
if resp.status_code != 200:
|
| 175 |
+
self.logger.error("Mistral embedding failed: %s %s", resp.status_code, resp.text)
|
| 176 |
+
return None
|
| 177 |
+
|
| 178 |
+
data = resp.json()
|
| 179 |
+
|
| 180 |
+
try:
|
| 181 |
+
embeddings = [item["embedding"] for item in data["data"]]
|
| 182 |
+
except (KeyError, TypeError):
|
| 183 |
+
self.logger.error("No embeddings returned from Mistral")
|
| 184 |
+
return None
|
| 185 |
+
|
| 186 |
+
if not embeddings:
|
| 187 |
+
self.logger.error("No embeddings returned from Mistral")
|
| 188 |
+
return None
|
| 189 |
+
|
| 190 |
+
self.logger.debug(f"Received {len(embeddings)} embeddings")
|
| 191 |
+
all_embeddings.extend(embeddings)
|
| 192 |
+
|
| 193 |
+
self.logger.info(f"Total embeddings created: {len(all_embeddings)}")
|
| 194 |
+
return all_embeddings
|
| 195 |
+
|
| 196 |
+
def clean_content(self, text: str) -> str:
|
| 197 |
+
text = re.sub(r'\[.*?\]\(.*?\)', '', text)
|
| 198 |
+
text = re.sub(r'\[[^\]]*\]', '', text)
|
| 199 |
+
text = re.sub(r'\n+', '\n', text).strip()
|
| 200 |
+
return text
|
| 201 |
+
|
| 202 |
+
def web_search(self, query: str):
|
| 203 |
+
"""Mistral has no native web search — returns a not-supported notice."""
|
| 204 |
+
self.logger.warning("MistralProvider.web_search is not natively supported.")
|
| 205 |
+
return {
|
| 206 |
+
"text": "Web search is not natively supported by the Mistral API.",
|
| 207 |
+
"references": []
|
| 208 |
+
}
|
stores/llm/providers/OllamaProvider.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from stores.llm.LLMInterface import LLMInterface
|
| 2 |
+
import logging
|
| 3 |
+
import requests
|
| 4 |
+
import re
|
| 5 |
+
import ollama
|
| 6 |
+
import os
|
| 7 |
+
class OllamaProvider(LLMInterface):
|
| 8 |
+
def __init__(self, url: str=None, model: str=None,
|
| 9 |
+
default_input_max_characters: int=1000,
|
| 10 |
+
default_generation_max_output_tokens: int=1000,
|
| 11 |
+
default_generation_temperature: float=0.1, api_key: str=None):
|
| 12 |
+
self.url = url or "http://localhost:11434"
|
| 13 |
+
self.api_key = api_key or os.getenv("OLLAMA_API_KEY")
|
| 14 |
+
self.model = model
|
| 15 |
+
self.generation_model_id = None
|
| 16 |
+
|
| 17 |
+
self.embedding_model = None
|
| 18 |
+
self.embedding_model_id = None
|
| 19 |
+
self.embedding_size = None
|
| 20 |
+
|
| 21 |
+
self.default_input_max_characters = default_input_max_characters
|
| 22 |
+
self.default_generation_max_output_tokens = default_generation_max_output_tokens
|
| 23 |
+
self.default_generation_temperature = default_generation_temperature
|
| 24 |
+
self.logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
def set_generation_model(self, model_id: str):
|
| 27 |
+
if model_id:
|
| 28 |
+
self.model = model_id
|
| 29 |
+
|
| 30 |
+
def set_embedding_model(self, model_id: str, embedding_size: int):
|
| 31 |
+
if model_id:
|
| 32 |
+
self.embedding_model = model_id
|
| 33 |
+
self.embedding_size = embedding_size
|
| 34 |
+
self.embedding_model_id = model_id
|
| 35 |
+
|
| 36 |
+
def process_text(self, text: str):
|
| 37 |
+
if not text:
|
| 38 |
+
return ""
|
| 39 |
+
return str(text).strip()
|
| 40 |
+
|
| 41 |
+
def generate_text(self, prompt: str, chat_history: list = None,
|
| 42 |
+
max_output_tokens: int = None, temperature: float = None):
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
try:
|
| 46 |
+
chat_history = chat_history or [] # safe handling
|
| 47 |
+
clean_prompt = self.process_text(prompt)
|
| 48 |
+
|
| 49 |
+
# Build payload with correct Ollama keys
|
| 50 |
+
payload = {
|
| 51 |
+
"model": self.model,
|
| 52 |
+
"prompt": clean_prompt,
|
| 53 |
+
"stream": False,
|
| 54 |
+
"num_predict": int(max_output_tokens or self.default_generation_max_output_tokens),
|
| 55 |
+
"temperature": float(temperature or self.default_generation_temperature),
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
url = self.url.rstrip("/") + "/api/generate"
|
| 59 |
+
headers = {}
|
| 60 |
+
if self.api_key:
|
| 61 |
+
headers["Authorization"] = f"Bearer {self.api_key}"
|
| 62 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=6000)
|
| 63 |
+
if resp.status_code != 200:
|
| 64 |
+
self.logger.error("Ollama generate failed: %s %s", resp.status_code, resp.text)
|
| 65 |
+
return None
|
| 66 |
+
|
| 67 |
+
data = resp.json()
|
| 68 |
+
|
| 69 |
+
# Extract final generated text correctly
|
| 70 |
+
generated_text = data.get("response", "").strip()
|
| 71 |
+
|
| 72 |
+
# If nothing generated, treat as failure
|
| 73 |
+
if not generated_text:
|
| 74 |
+
return None
|
| 75 |
+
|
| 76 |
+
# Return clean JSON instead of raw text
|
| 77 |
+
return {
|
| 78 |
+
"model": data.get("model"),
|
| 79 |
+
"response": generated_text,
|
| 80 |
+
"tokens_generated": data.get("eval_count"),
|
| 81 |
+
"total_duration_ms": round(data.get("total_duration", 0) / 1e6, 2),
|
| 82 |
+
"prompt_eval_tokens": data.get("prompt_eval_count"),
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
except Exception as e:
|
| 86 |
+
self.logger.exception("Error in OllamaProvider.generate_text: %s", e)
|
| 87 |
+
return None
|
| 88 |
+
|
| 89 |
+
def embed_text(self, text: str, document_type: str = None):
|
| 90 |
+
"""Return an embedding vector from Ollama."""
|
| 91 |
+
try:
|
| 92 |
+
if not self.embedding_model:
|
| 93 |
+
self.logger.error("Embedding model is not set before calling embed_text()")
|
| 94 |
+
return None
|
| 95 |
+
|
| 96 |
+
clean_text = self.process_text(text)
|
| 97 |
+
print(f"[DEBUG] Cleaned text: '{clean_text[:20]}...'")
|
| 98 |
+
if not clean_text:
|
| 99 |
+
return []
|
| 100 |
+
|
| 101 |
+
payload = {
|
| 102 |
+
"model": self.embedding_model,
|
| 103 |
+
"input": clean_text
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
url = self.url.rstrip("/") + "/api/embed"
|
| 107 |
+
headers = {}
|
| 108 |
+
if self.api_key:
|
| 109 |
+
headers["Authorization"] = f"Bearer {self.api_key}"
|
| 110 |
+
|
| 111 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=400)
|
| 112 |
+
if resp.status_code != 200:
|
| 113 |
+
print(f"[ERROR] Ollama embedding failed: {resp.status_code} {resp.text}")
|
| 114 |
+
return None
|
| 115 |
+
|
| 116 |
+
data = resp.json()
|
| 117 |
+
|
| 118 |
+
# Expected format: { "embedding": [...] }
|
| 119 |
+
if "embedding" in data:
|
| 120 |
+
print(f"[DEBUG] Embedding length: {len(data['embedding'])}")
|
| 121 |
+
return data["embedding"]
|
| 122 |
+
elif "embeddings" in data:
|
| 123 |
+
return data["embeddings"][0]
|
| 124 |
+
|
| 125 |
+
print("[WARNING] 'embedding' key not found in response JSON")
|
| 126 |
+
return None
|
| 127 |
+
|
| 128 |
+
except Exception as e:
|
| 129 |
+
print(f"[EXCEPTION] Error in OllamaProvider.embed_text: {e}")
|
| 130 |
+
return None
|
| 131 |
+
|
| 132 |
+
def construct_prompt(self, prompt: str, role: str):
|
| 133 |
+
return {
|
| 134 |
+
"role": role,
|
| 135 |
+
"content": self.process_text(prompt)
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
# def embed_text_batch(self, texts: list[str], batch_size: int = 32):
|
| 139 |
+
|
| 140 |
+
# self.logger.info(f"Embedding {len(texts)} texts using batch_size={batch_size}")
|
| 141 |
+
|
| 142 |
+
# if not self.embedding_model:
|
| 143 |
+
# self.logger.error("Embedding model not set")
|
| 144 |
+
# return None
|
| 145 |
+
|
| 146 |
+
# all_embeddings = []
|
| 147 |
+
|
| 148 |
+
# url = self.url.rstrip("/") + "/api/embed"
|
| 149 |
+
|
| 150 |
+
# for i in range(0, len(texts), batch_size):
|
| 151 |
+
# batch = texts[i:i + batch_size]
|
| 152 |
+
|
| 153 |
+
# clean_batch = [self.process_text(t) for t in batch if t]
|
| 154 |
+
|
| 155 |
+
# print(f"[EMBED] Embedding {len(texts)} texts using batch_size={batch_size} Progress = {i+batch_size}")
|
| 156 |
+
|
| 157 |
+
# payload = {
|
| 158 |
+
# "model": self.embedding_model,
|
| 159 |
+
# "input": clean_batch
|
| 160 |
+
# }
|
| 161 |
+
|
| 162 |
+
# resp = requests.post(url, json=payload, timeout=400)
|
| 163 |
+
|
| 164 |
+
# if resp.status_code != 200:
|
| 165 |
+
# self.logger.error("Ollama embedding failed: %s %s", resp.status_code, resp.text)
|
| 166 |
+
# return None
|
| 167 |
+
|
| 168 |
+
# data = resp.json()
|
| 169 |
+
|
| 170 |
+
# embeddings = data.get("embeddings")
|
| 171 |
+
|
| 172 |
+
# if not embeddings:
|
| 173 |
+
# self.logger.error("No embeddings returned from Ollama")
|
| 174 |
+
# return None
|
| 175 |
+
|
| 176 |
+
# self.logger.debug(f"Received {len(embeddings)} embeddings")
|
| 177 |
+
|
| 178 |
+
# all_embeddings.extend(embeddings)
|
| 179 |
+
|
| 180 |
+
# self.logger.info(f"Total embeddings created: {len(all_embeddings)}")
|
| 181 |
+
|
| 182 |
+
# return all_embeddings
|
| 183 |
+
|
| 184 |
+
def embed_text_batch(self, texts: list[str], batch_size: int = 64):
|
| 185 |
+
"""
|
| 186 |
+
Batch embedding for a list of texts, compatible with both /api/embed (new) and /api/embeddings (legacy).
|
| 187 |
+
Logs progress and returns a list of embedding vectors.
|
| 188 |
+
"""
|
| 189 |
+
all_embeddings = []
|
| 190 |
+
|
| 191 |
+
endpoints = ["/api/embed", "/api/embeddings"]
|
| 192 |
+
headers = {"Content-Type": "application/json"}
|
| 193 |
+
if self.api_key:
|
| 194 |
+
headers["Authorization"] = f"Bearer {self.api_key}"
|
| 195 |
+
|
| 196 |
+
total_texts = len(texts)
|
| 197 |
+
self.logger.info(f"Starting batch embedding of {total_texts} texts with batch_size={batch_size}")
|
| 198 |
+
|
| 199 |
+
for ep in endpoints:
|
| 200 |
+
try:
|
| 201 |
+
for i in range(0, total_texts, batch_size):
|
| 202 |
+
batch = texts[i:i + batch_size]
|
| 203 |
+
clean_batch = [self.process_text(t) for t in batch if t]
|
| 204 |
+
|
| 205 |
+
payload = {"model": self.embedding_model}
|
| 206 |
+
|
| 207 |
+
if ep == "/api/embed":
|
| 208 |
+
payload["input"] = clean_batch
|
| 209 |
+
resp = requests.post(self.url.rstrip("/") + ep, json=payload, headers=headers, timeout=400)
|
| 210 |
+
if resp.status_code != 200:
|
| 211 |
+
self.logger.warning(
|
| 212 |
+
"Batch embedding failed at %s: %s %s", ep, resp.status_code, resp.text
|
| 213 |
+
)
|
| 214 |
+
continue
|
| 215 |
+
|
| 216 |
+
data = resp.json()
|
| 217 |
+
embeddings = data.get("embeddings") or ([data.get("embedding")] if "embedding" in data else [])
|
| 218 |
+
all_embeddings.extend(embeddings)
|
| 219 |
+
|
| 220 |
+
else:
|
| 221 |
+
# Legacy endpoint: send individually
|
| 222 |
+
for j, t in enumerate(clean_batch):
|
| 223 |
+
payload_legacy = {"model": self.embedding_model, "prompt": t}
|
| 224 |
+
resp = requests.post(self.url.rstrip("/") + ep, json=payload_legacy, headers=headers, timeout=400)
|
| 225 |
+
if resp.status_code != 200:
|
| 226 |
+
self.logger.warning(
|
| 227 |
+
"Legacy embedding failed at %s: %s %s", ep, resp.status_code, resp.text
|
| 228 |
+
)
|
| 229 |
+
continue
|
| 230 |
+
|
| 231 |
+
data = resp.json()
|
| 232 |
+
if "embedding" in data:
|
| 233 |
+
all_embeddings.append(data["embedding"])
|
| 234 |
+
self.logger.info(f"Embedded {i+j+1}/{total_texts} texts using legacy endpoint")
|
| 235 |
+
|
| 236 |
+
# Log batch progress
|
| 237 |
+
self.logger.info(f"Embedded {min(i+batch_size, total_texts)}/{total_texts} texts using {ep}")
|
| 238 |
+
|
| 239 |
+
if all_embeddings:
|
| 240 |
+
self.logger.info(f"Finished embedding {len(all_embeddings)}/{total_texts} texts successfully")
|
| 241 |
+
break # stop after successful endpoint
|
| 242 |
+
|
| 243 |
+
except Exception as e:
|
| 244 |
+
self.logger.exception("Batch embedding error at %s: %s", ep, e)
|
| 245 |
+
|
| 246 |
+
return all_embeddings
|
| 247 |
+
|
| 248 |
+
def clean_content(self, text: str) -> str:
|
| 249 |
+
text = re.sub(r'\[.*?\]\(.*?\)', '', text)
|
| 250 |
+
text = re.sub(r'\[[^\]]*\]', '', text)
|
| 251 |
+
text = re.sub(r'\n+', '\n', text).strip()
|
| 252 |
+
return text
|
| 253 |
+
|
| 254 |
+
def web_search(self, query: str):
|
| 255 |
+
"""Use Ollama client to perform web search and return cleaned text + references."""
|
| 256 |
+
try:
|
| 257 |
+
# Use your old working Ollama client
|
| 258 |
+
OLLAMA_API_KEY = os.getenv("OLLAMA_API_KEY")
|
| 259 |
+
ollama_client = ollama.Client(headers={'Authorization': 'Bearer ' + OLLAMA_API_KEY})
|
| 260 |
+
response = ollama_client.web_search(query)
|
| 261 |
+
|
| 262 |
+
if not response or "results" not in response or len(response["results"]) == 0:
|
| 263 |
+
return {
|
| 264 |
+
"text": "No relevant external results found.",
|
| 265 |
+
"references": []
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
combined_text = []
|
| 269 |
+
references = set()
|
| 270 |
+
|
| 271 |
+
for item in response["results"]:
|
| 272 |
+
text = self.clean_content(item.content)
|
| 273 |
+
combined_text.append(text)
|
| 274 |
+
|
| 275 |
+
urls = re.findall(r"https?://[^\s)]+", item.content)
|
| 276 |
+
for url in urls:
|
| 277 |
+
references.add(url)
|
| 278 |
+
|
| 279 |
+
if hasattr(item, "url") and item.url:
|
| 280 |
+
references.add(item.url)
|
| 281 |
+
|
| 282 |
+
return {
|
| 283 |
+
"text": "\n\n".join(combined_text[:3]),
|
| 284 |
+
"references": list(references)
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
except Exception as e:
|
| 288 |
+
self.logger.error("Ollama web search failed: %s", e)
|
| 289 |
+
return {
|
| 290 |
+
"text": f"Ollama search error: {str(e)}",
|
| 291 |
+
"references": []
|
| 292 |
+
}
|
stores/llm/providers/OpenAIProvider.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..LLMInterface import LLMInterface
|
| 2 |
+
from ..LLMEnums import OpenAIEnums
|
| 3 |
+
from openai import OpenAI
|
| 4 |
+
import logging
|
| 5 |
+
|
| 6 |
+
class OpenAIProvider(LLMInterface):
|
| 7 |
+
def __init__(self, api_key: str, api_url: str = None,
|
| 8 |
+
default_input_max_characters: int = 1000,
|
| 9 |
+
default_generation_max_output_tokens: int = 1000,
|
| 10 |
+
default_generation_temperature: float = 0.1):
|
| 11 |
+
self.api_key = api_key
|
| 12 |
+
self.api_url = api_url
|
| 13 |
+
self.default_input_max_characters = default_input_max_characters
|
| 14 |
+
self.default_generation_max_output_tokens = default_generation_max_output_tokens
|
| 15 |
+
self.default_generation_temperature = default_generation_temperature
|
| 16 |
+
|
| 17 |
+
self.generation_model_id = None
|
| 18 |
+
self.embedding_model_id = None
|
| 19 |
+
self.embedding_size = None
|
| 20 |
+
|
| 21 |
+
self.client = OpenAI(api_key=self.api_key, base_url=self.api_url)
|
| 22 |
+
self.logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
def set_generation_model(self, model_id: str):
|
| 25 |
+
self.generation_model_id = model_id
|
| 26 |
+
|
| 27 |
+
def set_embedding_model(self, model_id: str, embedding_size: int):
|
| 28 |
+
self.embedding_model_id = model_id
|
| 29 |
+
self.embedding_size = embedding_size
|
| 30 |
+
|
| 31 |
+
def process_text(self, text: str):
|
| 32 |
+
return text[:self.default_input_max_characters].strip()
|
| 33 |
+
|
| 34 |
+
def generate_text(self, prompt: str, chat_history: list = None,
|
| 35 |
+
max_output_tokens: int = None, temperature: float = None):
|
| 36 |
+
if not self.client:
|
| 37 |
+
self.logger.error("OpenAI client was not initialized")
|
| 38 |
+
return None
|
| 39 |
+
|
| 40 |
+
if not self.generation_model_id:
|
| 41 |
+
self.logger.error("OpenAI generation model not set")
|
| 42 |
+
return None
|
| 43 |
+
|
| 44 |
+
max_output_tokens = max_output_tokens or self.default_generation_max_output_tokens
|
| 45 |
+
temperature = temperature or self.default_generation_temperature
|
| 46 |
+
|
| 47 |
+
messages = chat_history[:] if chat_history else []
|
| 48 |
+
messages.append(self.construct_prompt(prompt, OpenAIEnums.USER.value))
|
| 49 |
+
|
| 50 |
+
try:
|
| 51 |
+
response = self.client.chat.completions.create(
|
| 52 |
+
model=self.generation_model_id,
|
| 53 |
+
messages=messages,
|
| 54 |
+
max_completion_tokens=max_output_tokens,
|
| 55 |
+
temperature=temperature
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
if (not response or not response.choices
|
| 59 |
+
or not response.choices[0].message
|
| 60 |
+
or not response.choices[0].message.content):
|
| 61 |
+
self.logger.error("Invalid OpenAI response format")
|
| 62 |
+
return None
|
| 63 |
+
|
| 64 |
+
return response.choices[0].message.content
|
| 65 |
+
|
| 66 |
+
except Exception as e:
|
| 67 |
+
self.logger.exception("Error while generating text with OpenAI: %s", e)
|
| 68 |
+
return None
|
| 69 |
+
|
| 70 |
+
def embed_text_batch(self, texts: list[str], batch_size: int = 32):
|
| 71 |
+
pass
|
| 72 |
+
|
| 73 |
+
def embed_text(self, text: str, document_type: str = None):
|
| 74 |
+
if not self.client:
|
| 75 |
+
self.logger.error("OpenAI client was not initialized")
|
| 76 |
+
return None
|
| 77 |
+
|
| 78 |
+
if not self.embedding_model_id:
|
| 79 |
+
self.logger.error("OpenAI embedding model not set")
|
| 80 |
+
return None
|
| 81 |
+
|
| 82 |
+
try:
|
| 83 |
+
response = self.client.embeddings.create(
|
| 84 |
+
model=self.embedding_model_id,
|
| 85 |
+
input=text
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
if not response or not response.data or not response.data[0].embedding:
|
| 89 |
+
self.logger.error("Invalid OpenAI embedding response")
|
| 90 |
+
return None
|
| 91 |
+
|
| 92 |
+
return response.data[0].embedding
|
| 93 |
+
|
| 94 |
+
except Exception as e:
|
| 95 |
+
self.logger.exception("Error while embedding text with OpenAI: %s", e)
|
| 96 |
+
return None
|
| 97 |
+
|
| 98 |
+
def construct_prompt(self, prompt: str, role: str):
|
| 99 |
+
return {
|
| 100 |
+
"role": role,
|
| 101 |
+
"content": self.process_text(prompt)
|
| 102 |
+
}
|
stores/llm/providers/OpenRouterProvider.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from stores.llm.LLMInterface import LLMInterface
|
| 2 |
+
import logging
|
| 3 |
+
import requests
|
| 4 |
+
import re
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class OpenRouterProvider(LLMInterface):
|
| 9 |
+
def __init__(self, url: str = None, model: str = None,
|
| 10 |
+
default_input_max_characters: int = 1000,
|
| 11 |
+
default_generation_max_output_tokens: int = 1000,
|
| 12 |
+
default_generation_temperature: float = 0.1, api_key: str = None):
|
| 13 |
+
self.url = url or "https://openrouter.ai/api/v1"
|
| 14 |
+
self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
|
| 15 |
+
self.model = model
|
| 16 |
+
self.generation_model_id = None
|
| 17 |
+
|
| 18 |
+
self.embedding_model = None
|
| 19 |
+
self.embedding_model_id = None
|
| 20 |
+
self.embedding_size = None
|
| 21 |
+
|
| 22 |
+
self.default_input_max_characters = default_input_max_characters
|
| 23 |
+
self.default_generation_max_output_tokens = default_generation_max_output_tokens
|
| 24 |
+
self.default_generation_temperature = default_generation_temperature
|
| 25 |
+
self.logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
def set_generation_model(self, model_id: str):
|
| 28 |
+
if model_id:
|
| 29 |
+
self.model = model_id
|
| 30 |
+
|
| 31 |
+
def set_embedding_model(self, model_id: str, embedding_size: int):
|
| 32 |
+
if model_id:
|
| 33 |
+
self.embedding_model = model_id
|
| 34 |
+
self.embedding_size = embedding_size
|
| 35 |
+
self.embedding_model_id = model_id
|
| 36 |
+
|
| 37 |
+
def process_text(self, text: str):
|
| 38 |
+
if not text:
|
| 39 |
+
return ""
|
| 40 |
+
return str(text).strip()
|
| 41 |
+
|
| 42 |
+
def generate_text(self, prompt: str, chat_history: list = None,
|
| 43 |
+
max_output_tokens: int = None, temperature: float = None):
|
| 44 |
+
try:
|
| 45 |
+
chat_history = chat_history or []
|
| 46 |
+
clean_prompt = self.process_text(prompt)
|
| 47 |
+
|
| 48 |
+
messages = []
|
| 49 |
+
for entry in chat_history:
|
| 50 |
+
messages.append({
|
| 51 |
+
"role": entry.get("role", "user"),
|
| 52 |
+
"content": entry.get("content", "")
|
| 53 |
+
})
|
| 54 |
+
messages.append({"role": "user", "content": clean_prompt})
|
| 55 |
+
|
| 56 |
+
payload = {
|
| 57 |
+
"model": self.model,
|
| 58 |
+
"messages": messages,
|
| 59 |
+
"max_tokens": int(max_output_tokens or self.default_generation_max_output_tokens),
|
| 60 |
+
"temperature": float(temperature or self.default_generation_temperature),
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
url = self.url.rstrip("/") + "/chat/completions"
|
| 64 |
+
headers = {
|
| 65 |
+
"Authorization": f"Bearer {self.api_key}",
|
| 66 |
+
"Content-Type": "application/json",
|
| 67 |
+
# Recommended by OpenRouter for usage tracking
|
| 68 |
+
"HTTP-Referer": os.getenv("OPENROUTER_SITE_URL", "http://localhost"),
|
| 69 |
+
"X-Title": os.getenv("OPENROUTER_APP_NAME", "LLMApp"),
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=6000)
|
| 73 |
+
if resp.status_code != 200:
|
| 74 |
+
self.logger.error("OpenRouter generate failed: %s %s", resp.status_code, resp.text)
|
| 75 |
+
return None
|
| 76 |
+
|
| 77 |
+
data = resp.json()
|
| 78 |
+
|
| 79 |
+
try:
|
| 80 |
+
generated_text = data["choices"][0]["message"]["content"].strip()
|
| 81 |
+
except (KeyError, IndexError, TypeError):
|
| 82 |
+
self.logger.error("Unexpected OpenRouter response structure: %s", data)
|
| 83 |
+
return None
|
| 84 |
+
|
| 85 |
+
if not generated_text:
|
| 86 |
+
return None
|
| 87 |
+
|
| 88 |
+
usage = data.get("usage", {})
|
| 89 |
+
return {
|
| 90 |
+
"model": data.get("model"),
|
| 91 |
+
"response": generated_text,
|
| 92 |
+
"tokens_generated": usage.get("completion_tokens"),
|
| 93 |
+
"total_duration_ms": None,
|
| 94 |
+
"prompt_eval_tokens": usage.get("prompt_tokens"),
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
except Exception as e:
|
| 98 |
+
self.logger.exception("Error in OpenRouterProvider.generate_text: %s", e)
|
| 99 |
+
return None
|
| 100 |
+
|
| 101 |
+
def embed_text(self, text: str, document_type: str = None):
|
| 102 |
+
"""OpenRouter does not support embeddings natively — returns None."""
|
| 103 |
+
self.logger.warning("OpenRouterProvider does not support embeddings.")
|
| 104 |
+
return None
|
| 105 |
+
|
| 106 |
+
def construct_prompt(self, prompt: str, role: str):
|
| 107 |
+
return {
|
| 108 |
+
"role": role,
|
| 109 |
+
"content": self.process_text(prompt)
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
def embed_text_batch(self, texts: list[str], batch_size: int = 32):
|
| 113 |
+
"""OpenRouter does not support embeddings natively — returns None."""
|
| 114 |
+
self.logger.warning("OpenRouterProvider does not support embeddings.")
|
| 115 |
+
return None
|
| 116 |
+
|
| 117 |
+
def clean_content(self, text: str) -> str:
|
| 118 |
+
text = re.sub(r'\[.*?\]\(.*?\)', '', text)
|
| 119 |
+
text = re.sub(r'\[[^\]]*\]', '', text)
|
| 120 |
+
text = re.sub(r'\n+', '\n', text).strip()
|
| 121 |
+
return text
|
| 122 |
+
|
| 123 |
+
def web_search(self, query: str):
|
| 124 |
+
"""
|
| 125 |
+
OpenRouter supports online models (e.g. perplexity/sonar-online) that have
|
| 126 |
+
built-in web search. Route the query through one of those models if available,
|
| 127 |
+
otherwise fall back to a not-supported notice.
|
| 128 |
+
"""
|
| 129 |
+
try:
|
| 130 |
+
online_model = os.getenv("OPENROUTER_SEARCH_MODEL", "perplexity/sonar-online")
|
| 131 |
+
|
| 132 |
+
payload = {
|
| 133 |
+
"model": online_model,
|
| 134 |
+
"messages": [{"role": "user", "content": query}],
|
| 135 |
+
"max_tokens": int(self.default_generation_max_output_tokens),
|
| 136 |
+
"temperature": float(self.default_generation_temperature),
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
url = self.url.rstrip("/") + "/chat/completions"
|
| 140 |
+
headers = {
|
| 141 |
+
"Authorization": f"Bearer {self.api_key}",
|
| 142 |
+
"Content-Type": "application/json",
|
| 143 |
+
"HTTP-Referer": os.getenv("OPENROUTER_SITE_URL", "http://localhost"),
|
| 144 |
+
"X-Title": os.getenv("OPENROUTER_APP_NAME", "LLMApp"),
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
resp = requests.post(url, json=payload, headers=headers, timeout=6000)
|
| 148 |
+
if not resp or resp.status_code != 200:
|
| 149 |
+
return {
|
| 150 |
+
"text": "No relevant external results found.",
|
| 151 |
+
"references": []
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
data = resp.json()
|
| 155 |
+
|
| 156 |
+
combined_text = []
|
| 157 |
+
references = set()
|
| 158 |
+
|
| 159 |
+
try:
|
| 160 |
+
text_content = data["choices"][0]["message"]["content"]
|
| 161 |
+
combined_text.append(self.clean_content(text_content))
|
| 162 |
+
except (KeyError, IndexError, TypeError):
|
| 163 |
+
pass
|
| 164 |
+
|
| 165 |
+
# Extract any URLs from the response text
|
| 166 |
+
for found_url in re.findall(r"https?://[^\s)]+", "\n".join(combined_text)):
|
| 167 |
+
references.add(found_url)
|
| 168 |
+
|
| 169 |
+
return {
|
| 170 |
+
"text": "\n\n".join(combined_text[:3]),
|
| 171 |
+
"references": list(references)
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
except Exception as e:
|
| 175 |
+
self.logger.error("OpenRouter web search failed: %s", e)
|
| 176 |
+
return {
|
| 177 |
+
"text": f"OpenRouter search error: {str(e)}",
|
| 178 |
+
"references": []
|
| 179 |
+
}
|
stores/llm/providers/__init__.py
ADDED
|
File without changes
|