MinaNasser committed on
Commit
1bc3f18
·
1 Parent(s): 9676e57
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .env +115 -0
  2. .gitignore +17 -0
  3. Dockerfile +36 -0
  4. README.md +21 -4
  5. celery_app.py +30 -0
  6. config.py +64 -0
  7. docker-entrypoint.sh +14 -0
  8. generation/AssistantRagGenerator.py +201 -0
  9. generation/ExamAnswer.py +314 -0
  10. generation/ExamRagGenerator.py +460 -0
  11. generation/__init__.py +0 -0
  12. generation/answer_models.py +51 -0
  13. generation/parsing_utils.py +51 -0
  14. generation/prompts.py +250 -0
  15. indexing/indexingController.py +111 -0
  16. ingestion/chunkers/__init__.py +0 -0
  17. ingestion/chunkers/fixed_chunker.py +10 -0
  18. ingestion/chunkers/recursive_chunker.py +9 -0
  19. ingestion/loaders/File_loader.py +57 -0
  20. ingestion/loaders/__init__.py +1 -0
  21. ingestion/loaders/docx_loader.py +89 -0
  22. ingestion/loaders/md_loader.py +48 -0
  23. ingestion/loaders/normalization.py +35 -0
  24. ingestion/loaders/pdf_loader.py +66 -0
  25. ingestion/loaders/txt_loader.py +38 -0
  26. ingestion/pdf_outline.py +62 -0
  27. main.py +20 -0
  28. requirements.txt +19 -0
  29. routes/__init__.py +0 -0
  30. routes/assisstant_rag.py +165 -0
  31. routes/base.py +45 -0
  32. routes/exam_grading_router.py +122 -0
  33. routes/exam_router.py +15 -0
  34. routes/schemas/Exam_Models.py +180 -0
  35. routes/schemas/Requests_Models.py +24 -0
  36. routes/schemas/__init__.py +0 -0
  37. stores/llm/LLMEnums.py +28 -0
  38. stores/llm/LLMInterface.py +24 -0
  39. stores/llm/LLMProviderFactory.py +93 -0
  40. stores/llm/__init__.py +0 -0
  41. stores/llm/providers/CohereProvider.py +395 -0
  42. stores/llm/providers/DeepSeekProvider.py +126 -0
  43. stores/llm/providers/GeminiProvider.py +305 -0
  44. stores/llm/providers/GroqProvider.py +133 -0
  45. stores/llm/providers/HuggingFaceProvider.py +214 -0
  46. stores/llm/providers/MistralProvider.py +208 -0
  47. stores/llm/providers/OllamaProvider.py +292 -0
  48. stores/llm/providers/OpenAIProvider.py +102 -0
  49. stores/llm/providers/OpenRouterProvider.py +179 -0
  50. stores/llm/providers/__init__.py +0 -0
.env ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ APP_NAME="IntegraRAG"
2
+ DEBUG=False
3
+ CustomLoaders=False
4
+
5
+ # ---------- QDRANT ---------- Choose One
6
+ # QDRANT_TYPE="local"
7
+ # QDRANT_DOCKER_URL=""
8
+ # QDRANT_API_KEY=""
9
+
10
+ # QDRANT_TYPE="docker"
11
+ # QDRANT_DOCKER_URL="http://localhost:6333/"
12
+ # QDRANT_API_KEY=""
13
+
14
+ QDRANT_TYPE="cloud"
15
+ QDRANT_DOCKER_URL="https://d7e287d8-903d-436c-854c-03cbef9e4edb.us-east4-0.gcp.cloud.qdrant.io"
16
+ QDRANT_API_KEY="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.NRbT0QPl7isuBKvdtganh89xa2DeMgKXZ3gSJngexQg"
17
+
18
+
19
+ # ---------- REDIS ----------
20
+ REDIS_HOST="rediss://default:gQAAAAAAAS-BAAIncDFiM2E3OGQ1MmU5Zjk0OGM5ODU2ZmMzYzc4NjZjYzdjMHAxNzc2OTc@steady-clam-77697.upstash.io"
21
+ REDIS_PORT=6379
22
+
23
+ # ---------- WEBHOOKS ----------
24
+ CALLBACK_URL="https://webhooksite.net/c93aac48-5237-4078-9511-14d778acba2f"
25
+ GRADE_WEBHOOK_URL="https://webhooksite.net/c93aac48-5237-4078-9511-14d778acba2f"
26
+
27
+ # ---------- BACKENDS ---------- Choose One
28
+ #generation
29
+ # OLLAMA | COHERE | MISTRAL | GEMINI | HUGGINGFACE | GROQ | OPENROUTER | DEEPSEEK |
30
+ #embedding
31
+ # OLLAMA | COHERE | MISTRAL | GEMINI | HUGGINGFACE
32
+
33
+ # ---------- OLLAMA ----------
34
+ OLLAMA_URL="http://localhost:11434"
35
+ # OLLAMA_API_KEY="getAone"
36
+ # GENERATION_BACKEND="OLLAMA"
37
+ # EMBEDDING_BACKEND="OLLAMA"
38
+ # GENERATION_MODEL_ID="deepseek-v3.1:671b-cloud"
39
+ # EMBEDDING_MODEL_ID="embeddinggemma:latest"
40
+ # EMBEDDING_MODEL_SIZE=768
41
+ # QDRANT_COLLECTION="768_docs"
42
+
43
+
44
+ # ---------- COHERE ----------
45
+ COHERE_API_KEY="getAone"
46
+ # GENERATION_BACKEND="COHERE"
47
+ # EMBEDDING_BACKEND="COHERE"
48
+ # GENERATION_MODEL_ID="command-a-03-2025"
49
+ # EMBEDDING_MODEL_ID="embed-multilingual-v3.0"
50
+ # EMBEDDING_MODEL_SIZE=1024
51
+ # QDRANT_COLLECTION="1024_docs"
52
+
53
+
54
+ # ---------- MISTRAL ----------
55
+ MISTRAL_API_KEY="getAone"
56
+ # GENERATION_BACKEND="MISTRAL"
57
+ # EMBEDDING_BACKEND="MISTRAL"
58
+ # GENERATION_MODEL_ID="mistral-small-2603"
59
+ # EMBEDDING_MODEL_ID="mistral-embed-2312"
60
+ # EMBEDDING_MODEL_SIZE=1024
61
+ # QDRANT_COLLECTION="1024_docs"
62
+
63
+ # ---------- GEMINI ----------
64
+ GEMINI_API_KEY="getAone"
65
+ GENERATION_BACKEND="GEMINI"
66
+ EMBEDDING_BACKEND="GEMINI"
67
+ GENERATION_MODEL_ID="gemini-2.5-flash"
68
+ EMBEDDING_MODEL_ID="gemini-embedding-001"
69
+ EMBEDDING_MODEL_SIZE=768
70
+ QDRANT_COLLECTION="768_docs"
71
+
72
+ # ---------- HUGGING FACE ----------
73
+ HF_API_KEY="getAone"
74
+ # GENERATION_BACKEND="HUGGINGFACE"
75
+ # EMBEDDING_BACKEND="HUGGINGFACE"
76
+ # GENERATION_MODEL_ID="Qwen/Qwen2.5-72B-Instruct"
77
+ # EMBEDDING_MODEL_ID="google/embeddinggemma-300m"
78
+ # EMBEDDING_MODEL_SIZE=768
79
+ # QDRANT_COLLECTION="768_docs"
80
+
81
+ # ---------- DEEPSEEK ---------- paid
82
+ DEEPSEEK_API_KEY="getAone"
83
+ # GENERATION_BACKEND="DEEPSEEK"
84
+ # EMBEDDING_BACKEND="COHERE"
85
+ # GENERATION_MODEL_ID="deepseek-chat"
86
+ # EMBEDDING_MODEL_ID="embed-multilingual-v3.0"
87
+ # EMBEDDING_MODEL_SIZE=1024
88
+ # QDRANT_COLLECTION="1024_docs"
89
+
90
+
91
+ # ---------- OPENAI ---------- paid
92
+ OPENAI_API_KEY=""
93
+ OPENAI_API_URL=""
94
+
95
+
96
+ # ---------- GROQ ----------not complete
97
+ GROQ_API_KEY=""
98
+
99
+ # ---------- OPENROUTER ----------not complete
100
+ OPENROUTER_API_KEY=""
101
+ OPENROUTER_SITE_URL="http://localhost"
102
+ OPENROUTER_APP_NAME="IntegraRAG"
103
+ OPENROUTER_SEARCH_MODEL="perplexity/sonar-online"
104
+
105
+
106
+
107
+ # ---------- DEFAULTS ----------
108
+ INPUT_DAFAULT_MAX_CHARACTERS=2048
109
+ GENERATION_DAFAULT_MAX_TOKENS=1200
110
+ GENERATION_DAFAULT_TEMPERATURE=0.3
111
+
112
+ # ---------- CHUNKING ----------
113
+ CHUNK_SIZE=700
114
+ CHUNK_OVERLAP=150
115
+ CHUNK_METHOD="recursive"
.gitignore ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Python virtual environment
venv/
# No trailing slash: this ignores a `.env/` virtualenv directory AND the
# `.env` settings file, which holds API keys and must not be committed.
.env
__pycache__/
*.pyc
*.pyo
*.pyd

Code_Backups.txt
data/
# VSCode
.vscode/
.vs/

# OS
.DS_Store
Thumbs.db
Dockerfile ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ─────────────────────────────────────────────
2
+ # IntegraRAG — Production Dockerfile
3
+ # Services bundled: FastAPI + Celery worker
4
+ # External deps: Redis, Qdrant (cloud/managed)
5
+ # ─────────────────────────────────────────────
6
+ FROM python:3.11-slim
7
+
8
+ # System deps
9
+ RUN apt-get update && apt-get install -y --no-install-recommends \
10
+ build-essential \
11
+ curl \
12
+ libmagic1 \
13
+ && rm -rf /var/lib/apt/lists/*
14
+
15
+ WORKDIR /app
16
+
17
+ # Install Python deps first (layer cache)
18
+ COPY requirements.txt .
19
+ RUN pip install --no-cache-dir -r requirements.txt
20
+
21
+ # Copy application source
22
+ COPY . .
23
+
24
+ # ── Runtime env defaults (override via HF Secrets or docker run -e) ──
25
+ ENV PORT=7860 \
26
+ PYTHONUNBUFFERED=1 \
27
+ PYTHONDONTWRITEBYTECODE=1
28
+
29
+ # Hugging Face Spaces exposes port 7860
30
+ EXPOSE 7860
31
+
32
+ # Entrypoint: start Celery worker in background, then FastAPI
33
+ COPY docker-entrypoint.sh /docker-entrypoint.sh
34
+ RUN chmod +x /docker-entrypoint.sh
35
+
36
+ ENTRYPOINT ["/docker-entrypoint.sh"]
README.md CHANGED
@@ -1,10 +1,27 @@
1
  ---
2
- title: EXAM RAG API
3
- emoji: 💻
4
- colorFrom: gray
5
  colorTo: blue
6
  sdk: docker
 
7
  pinned: false
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: IntegraRAG API
3
+ emoji: 🧠
4
+ colorFrom: indigo
5
  colorTo: blue
6
  sdk: docker
7
+ app_port: 7860
8
  pinned: false
9
  ---
10
 
11
+ # IntegraRAG: RAG-Powered Exam & Assistant API
12
+
13
+ FastAPI backend with Celery workers for document-based Q&A, exam generation, and AI-graded exam submissions.
14
+
15
+
16
+ `conda create -n RAG_API python==3.11`
17
+ `conda activate RAG_API`
18
+ `pip install -r requirements.txt`
19
+
20
+ `docker run -p 6333:6333 qdrant/qdrant`
21
+ `docker run -d -p 6379:6379 redis:7`
22
+
23
+ # View The .env
24
+
25
+ `celery -A celery_app.celery_app worker -P threads --loglevel=info`
26
+ `uvicorn main:app --host 0.0.0.0 --port 8030 --reload`
27
+ `uvicorn webhook:app --reload`
celery_app.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # celery_app.py
2
+ from celery import Celery
3
+ import redis
4
+ from config import get_settings
5
+
6
def _redis_url(db: int) -> str:
    """Return the Redis connection URL for logical database *db*.

    ``REDIS_HOST`` may be a bare host name (the ``"localhost"`` default) or
    already carry a scheme and credentials (``rediss://user:pass@host``).
    A ``redis://`` scheme is added only when none is present, so both forms
    produce a valid URL.
    """
    settings = get_settings()
    host = settings.REDIS_HOST
    if "://" not in host:
        host = f"redis://{host}"
    return f"{host}:{settings.REDIS_PORT}/{db}"


# Broker uses Redis DB 0, the result backend uses DB 1.
# NOTE(review): Celery may require an `ssl_cert_reqs` query parameter for
# `rediss://` URLs - confirm against the deployed Redis endpoint.
celery_app = Celery(
    "assistant_worker",
    broker=_redis_url(0),
    backend=_redis_url(1),
    include=['generation.ExamAnswer'],
)

celery_app.conf.update(
    task_serializer="json",
    accept_content=["json"],
    result_serializer="json",
    task_track_started=True,
    task_time_limit=60*60,  # hard limit of one hour per task
)

# Imported after the app object exists; presumably these modules register
# tasks on import - confirm before reordering.
import worker.tasks
from generation.ExamAnswer import grade_exam_task


def clear_redis_backend():
    """Delete every key in the Redis result-backend database (DB 1).

    The connection is built from the same URL as the Celery backend, so a
    ``REDIS_HOST`` that already contains a scheme/credentials works here too.
    """
    r = redis.Redis.from_url(_redis_url(1))
    r.flushdb()
    print("Redis result backend cleared!")


@celery_app.on_after_configure.connect
def setup(sender, **kwargs):
    """Signal handler: wipe stored task results once the app is configured."""
    clear_redis_backend()
config.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic_settings import BaseSettings
2
+ from functools import lru_cache
3
+
4
+
5
class Settings(BaseSettings):
    """Application configuration, read from the environment and ``.env``.

    Fields without a default (``APP_NAME``) are required. Fields that default
    to ``None`` are optional and are annotated ``X | None`` so the declared
    type matches the value callers can actually receive.
    """

    DEBUG: bool = False
    APP_NAME: str
    QDRANT_COLLECTION: str = "docs"

    CustomLoaders: bool | None = None
    QDRANT_TYPE: str = "docker"
    QDRANT_DOCKER_URL: str = "http://localhost:6333"
    QDRANT_API_KEY: str | None = None
    CHUNK_SIZE: int = 1000
    CHUNK_OVERLAP: int | None = None
    CHUNK_METHOD: str | None = None
    GRADE_WEBHOOK_URL: str | None = None
    # May be a bare host name or a full URL including scheme and credentials.
    REDIS_HOST: str = "localhost"
    REDIS_PORT: int = 6379
    CALLBACK_URL: str | None = None

    # ---------- BACKENDS ----------
    GENERATION_BACKEND: str = "OLLAMA"
    EMBEDDING_BACKEND: str = "OLLAMA"

    # ---------- API KEYS ----------
    OPENAI_API_KEY: str | None = None
    OPENAI_API_URL: str | None = None

    COHERE_API_KEY: str | None = None

    OLLAMA_URL: str = "http://localhost:11434"
    OLLAMA_API_KEY: str | None = None

    MISTRAL_API_KEY: str | None = None

    GROQ_API_KEY: str | None = None

    OPENROUTER_API_KEY: str | None = None
    OPENROUTER_SITE_URL: str = "http://localhost"  # forwarded as HTTP-Referer
    OPENROUTER_APP_NAME: str = "IntegraRAG"  # forwarded as X-Title
    OPENROUTER_SEARCH_MODEL: str = "perplexity/sonar-online"

    HF_API_KEY: str | None = None

    DEEPSEEK_API_KEY: str | None = None

    GEMINI_API_KEY: str | None = None

    # ---------- MODELS ----------
    GENERATION_MODEL_ID: str = "deepseek-v3.1:671b-cloud"
    EMBEDDING_MODEL_ID: str = "embeddinggemma:latest"
    EMBEDDING_MODEL_SIZE: int = 768
    # The "DAFAULT" spelling is part of the environment-variable names used in
    # `.env`; renaming these fields would break existing configuration.
    INPUT_DAFAULT_MAX_CHARACTERS: int | None = None
    GENERATION_DAFAULT_MAX_TOKENS: int | None = None
    GENERATION_DAFAULT_TEMPERATURE: float | None = None

    class Config:
        env_file = ".env"
60
+
61
+
62
@lru_cache
def get_settings():
    """Return the shared ``Settings`` instance.

    The result is memoized by ``lru_cache``, so the settings are built on the
    first call and the same object is returned afterwards.
    """
    settings = Settings()
    return settings
docker-entrypoint.sh ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -e
3
+
4
+ echo "==> Starting Celery worker in background..."
5
+ celery -A celery_app.celery_app worker \
6
+ -P threads \
7
+ --loglevel=info \
8
+ --concurrency=4 &
9
+
10
+ echo "==> Starting FastAPI (uvicorn) on port ${PORT:-7860}..."
11
+ exec uvicorn main:app \
12
+ --host 0.0.0.0 \
13
+ --port "${PORT:-7860}" \
14
+ --workers 1
generation/AssistantRagGenerator.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ from pydantic import Field
3
+ from langchain_core.language_models import LLM
4
+ from langchain_core.runnables import RunnableBranch, RunnableLambda, RunnableParallel
5
+ from langchain_core.output_parsers import StrOutputParser
6
+ from langchain_core.prompts import PromptTemplate
7
+ from stores.llm.LLMProviderFactory import LLMProviderFactory
8
+ from config import get_settings
9
+
10
class ProviderLLMWrapper(LLM):
    """Adapt a project LLM provider to LangChain's ``LLM`` interface.

    The wrapped ``provider`` must expose ``generate_text(prompt)`` returning
    either a string or a dict holding the text under ``"response"``.
    """

    provider: Any = Field(..., description="The wrapped LLM provider")

    def _call(self, prompt: str, stop=None, run_manager=None, **kwargs) -> str:
        """Generate a completion for *prompt* and return it as a string.

        ``stop``, ``run_manager`` and extra keyword arguments are accepted so
        the signature matches LangChain's ``LLM._call`` contract (otherwise any
        keyword forwarded by the framework raises ``TypeError``). They are not
        passed on: the provider is called with the prompt only.

        Raises:
            ValueError: if the provider returns ``None``, a dict without a
                ``"response"`` value, or any other non-string type.
        """
        result = self.provider.generate_text(prompt)
        if result is None:
            raise ValueError("LLM provider returned None (likely due to timeout or error)")
        if isinstance(result, dict):
            response = result.get("response")
            if response is None:
                raise ValueError(f"LLM provider returned dict without 'response' key: {result.keys()}")
            return response
        if isinstance(result, str):
            return result
        raise ValueError(f"Unexpected LLM response type: {type(result).__name__}")

    @property
    def _llm_type(self):
        """Identifier LangChain uses to label this LLM implementation."""
        return "custom-provider"

    def get_num_tokens(self, text: str) -> int:
        """Approximate the token count as the number of whitespace-separated words."""
        return len(text.split())
31
+
32
class AssistantRagGen:
    """Route a user question to one of three prompt styles and answer it.

    A first LLM call classifies the question as ``user_info``, ``site_query``
    or ``pdf_query``; a second call answers it with the matching prompt.
    """

    def __init__(self):
        config = get_settings()
        self.factory = LLMProviderFactory(config)
        self.generator = self.factory.create(config.GENERATION_BACKEND)
        self.generator.set_generation_model(config.GENERATION_MODEL_ID)
        self.llm = ProviderLLMWrapper(provider=self.generator)
        self.valid_routes = {"user_info", "site_query", "pdf_query"}

    def build_router_prompt(self, user_prompt: str) -> str:
        """Return the classification prompt asking for exactly one route name."""
        return f"""You are a query routing classifier. Your sole job is to categorize a user's question into exactly one routing category.

## Categories

| Category | Routes questions about... |
|--------------|-------------------------------------------------------------------------------------------|
| `user_info` | Personal profile, enrolled courses, username, role, learning progress, achievements |
| `site_query` | Platform features, website navigation, rules, policies, FAQs, general platform knowledge |
| `pdf_query` | Document content, uploaded files, PDF search, lesson materials, reading resources |

## Examples

user_info → "What courses am I enrolled in?"
user_info → "What is my current progress in the Python course?"
site_query → "How do I reset my password?"
site_query → "What are the platform's refund policies?"
pdf_query → "What does the document say about recursion?"
pdf_query → "Find me the section on neural networks in the materials"

## Decision Rules

1. If the question involves the **current user's personal data** → `user_info`
2. If the question is about **how the platform works** → `site_query`
3. If the question requires **reading or searching a document** → `pdf_query`
4. When ambiguous, prefer `pdf_query` over `site_query`, and `user_info` over both.

## Output Format

Respond with a single lowercase word. No punctuation. No explanation. No whitespace.

Valid outputs: user_info | site_query | pdf_query

Question: {user_prompt}
"""

    def build_unified_prompt(self, context: str, question: str, conversation_history: str = "", User_Info: str = "", enrolled_courses: str = "") -> str:
        """Return the document-question prompt.

        ``enrolled_courses`` is optional; it is rendered in its own section
        because the rules tell the model to use it for user-related questions.
        Empty history, user info or courses are shown as ``None``.
        """
        return f"""
You are a helpful university assistant.

Rules:
- Use the provided context FIRST.
- Use conversation history to understand follow-up questions.
- If the question is about the user, use the User_Info and enrolled_courses.
- If the answer is not in the context, say:
"Not found in the provided materials."
Then add:
"From my own information:" and answer briefly.
- Be concise and clear.

Conversation History:
{conversation_history if conversation_history else "None"}

User Info:
{User_Info if User_Info else "None"}

Enrolled Courses:
{enrolled_courses if enrolled_courses else "None"}

Context:
{context}

Current Question:
{question}

Answer:
"""

    def build_user_info_prompt(self, question: str, conversation_history: str = "", User_Info: str = "", enrolled_courses: str = "") -> str:
        """Return the account-inquiry prompt built from user info and enrolled courses."""
        return f"""
You are a university assistant handling a user account inquiry.
Use the provided User Info and Enrolled Courses to answer the question accurately.

Conversation History:
{conversation_history if conversation_history else "None"}

User Info:
{User_Info if User_Info else "None"}

Enrolled Courses:
{enrolled_courses if enrolled_courses else "None"}

Current Question:
{question}

Answer:
"""

    def build_site_query_prompt(self, question: str, context: str = "", conversation_history: str = "") -> str:
        """Return the platform/site-question prompt."""
        return f"""
You are a university assistant handling a platform or site-related question.
Provide clear instructions, rules, or general information about how the university platform works.

Conversation History:
{conversation_history if conversation_history else "None"}

Current Question:
{question}

Site Context:
{context if context else "None"}

Answer:
"""

    def _normalize_route(self, raw):
        """Map a raw classifier reply to a valid route name, or ``None``.

        Models often wrap the label in backticks, quotes or trailing
        punctuation; those are stripped first. If the cleaned text is still not
        a route, the reply is accepted only when it mentions exactly one route
        name, so an ambiguous answer is never guessed.
        """
        text = str(raw).strip().lower()
        cleaned = text.strip("`'\"*.,:;!? \t\r\n")
        if cleaned in self.valid_routes:
            return cleaned
        mentioned = [route for route in self.valid_routes if route in text]
        if len(mentioned) == 1:
            return mentioned[0]
        return None

    def robust_router(self, input_data: dict) -> str:
        """Classify ``input_data["question"]`` into a route name.

        The classifier is asked up to three times; if no reply can be mapped to
        a valid route the method falls back to ``"pdf_query"``.
        """
        # The prompt does not change between attempts, so build it once.
        prompt = self.build_router_prompt(input_data["question"])
        for _ in range(3):
            route = self._normalize_route(self.llm.invoke(prompt))
            if route is not None:
                return route
        return "pdf_query"

    def get_chain(self):
        """Build the runnable: route the question, answer it, return a string.

        The input dict must contain ``question``; ``context``,
        ``conversation_history``, ``User_Info`` and ``enrolled_courses`` are
        optional and default to empty strings.
        """
        router_node = RunnableLambda(self.robust_router)

        user_info_chain = RunnableLambda(lambda x: self.llm.invoke(
            self.build_user_info_prompt(
                question=x["question"],
                conversation_history=x.get("conversation_history", ""),
                User_Info=x.get("User_Info", ""),
                enrolled_courses=x.get("enrolled_courses", ""),
            )
        ))

        site_query_chain = RunnableLambda(lambda x: self.llm.invoke(
            self.build_site_query_prompt(
                question=x["question"],
                context=x.get("context", ""),
                conversation_history=x.get("conversation_history", ""),
            )
        ))

        pdf_query_chain = RunnableLambda(lambda x: self.llm.invoke(
            self.build_unified_prompt(
                context=x.get("context", "No context provided."),
                question=x["question"],
                conversation_history=x.get("conversation_history", ""),
                User_Info=x.get("User_Info", ""),
                enrolled_courses=x.get("enrolled_courses", ""),
            )
        ))

        # Anything that is not user_info/site_query is answered as pdf_query.
        branching_logic = RunnableBranch(
            (lambda x: x["topic"] == "user_info", user_info_chain),
            (lambda x: x["topic"] == "site_query", site_query_chain),
            pdf_query_chain
        )

        full_chain = (
            RunnableParallel({
                "topic": router_node,
                # Pass all incoming variables straight through to the branches
                "question": lambda x: x["question"],
                "context": lambda x: x.get("context", ""),
                "conversation_history": lambda x: x.get("conversation_history", ""),
                "User_Info": lambda x: x.get("User_Info", ""),
                "enrolled_courses": lambda x: x.get("enrolled_courses", "")
            })
            | branching_logic
            | StrOutputParser()
        )

        return full_chain
201
+
generation/ExamAnswer.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from datetime import datetime
3
+ from typing import List, Dict, Any
4
+ from celery import shared_task
5
+ import json
6
+ import re
7
+ import httpx
8
+
9
+ from generation.answer_models import (ExamSubmission,ExamResult,StudentAnswer,GradedAnswer,QuestionType)
10
+ from indexing.indexingController import IndexingController
11
+ from stores.llm.LLMProviderFactory import LLMProviderFactory
12
+ from config import get_settings
13
+
14
+
15
def calculate_grade(percentage: float) -> str:
    """Convert a percentage into a letter grade.

    Thresholds are inclusive lower bounds: 90 -> A, 80 -> B, 70 -> C,
    60 -> D; anything below 60 is an F.
    """
    grade_floors = ((90, "A"), (80, "B"), (70, "C"), (60, "D"))
    for floor, letter in grade_floors:
        if percentage >= floor:
            return letter
    return "F"
26
+
27
+
28
logger = logging.getLogger(__name__)


class ExamGradingService:
    """Grade exam submissions.

    Multiple-choice and true/false answers are graded by exact comparison.
    Other answers are graded by the LLM with retrieved course context when
    ``use_ai_for_essays`` is set and a correct answer is available, otherwise
    by word-overlap similarity.
    """

    # Lower-cased tokens interpreted as "true" for true/false questions.
    _TRUE_TOKENS = frozenset({"true", "t", "1", "yes"})

    def __init__(self, use_ai_for_essays: bool = True):
        self.use_ai_for_essays = use_ai_for_essays

        config = get_settings()

        factory = LLMProviderFactory(config)
        provider = factory.create(config.GENERATION_BACKEND)
        provider.set_generation_model(config.GENERATION_MODEL_ID)
        self.llm = provider

        # Minimum score ratio for a free-text answer to count as correct.
        self.semantic_threshold = 0.65
        self.high_confidence = 0.85

    @classmethod
    def _as_bool(cls, value: Any) -> bool:
        """Interpret an answer-key value as a boolean.

        Booleans are returned unchanged; strings are trimmed, lower-cased and
        matched against ``_TRUE_TOKENS``; any other value uses ``bool()``.
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            return value.strip().lower() in cls._TRUE_TOKENS
        return bool(value)

    def grade_submission(self, submission: ExamSubmission) -> ExamResult:
        """Grade every answer in *submission* and aggregate the result.

        The correct answer for each question is read from the answer's
        ``metadata["correct_answer"]`` when metadata is present. The percentage
        is 0 when the maximum total score is 0.
        """
        graded_answers: List[GradedAnswer] = []
        total_score = 0
        max_total_score = 0

        for ans in submission.answers:
            correct_answer = None
            if ans.metadata:
                correct_answer = ans.metadata.get("correct_answer")

            graded = self.grade_answer(ans, correct_answer, submission.course_id)
            graded_answers.append(graded)
            total_score += graded.score
            max_total_score += graded.max_score

        percentage = (total_score / max_total_score) * 100 if max_total_score else 0
        grade = calculate_grade(percentage)

        return ExamResult(
            exam_id=submission.exam_id,
            student_id=submission.student_id,
            student_name=submission.student_name,
            graded_answers=graded_answers,
            total_score=total_score,
            max_total_score=max_total_score,
            percentage=percentage,
            grade=grade,
            feedback_summary="RAG based grading using LLM evaluation",
            submission_time=submission.submission_time,
            # Naive UTC timestamp (no timezone offset in the ISO string).
            graded_time=datetime.utcnow().isoformat()
        )

    def grade_answer(self, answer: StudentAnswer, correct_answer: Any, course) -> GradedAnswer:
        """Grade a single answer and return it with score and feedback.

        Args:
            answer: the student's answer, including question type and max score.
            correct_answer: the answer key, or ``None`` when unavailable.
            course: course identifier used to filter retrieved context.
        """
        if answer.question_type in (QuestionType.MULTIPLE_CHOICE, QuestionType.TRUE_FALSE):
            student_str = str(answer.student_response).strip().lower()
            if answer.question_type == QuestionType.TRUE_FALSE:
                student_bool = student_str in self._TRUE_TOKENS
                is_correct = student_bool == self._as_bool(correct_answer)
            else:  # multiple_choice
                # `is None` rather than truthiness: a key of 0 or False is a
                # real answer and must still be compared.
                correct_str = "" if correct_answer is None else str(correct_answer).strip().lower()
                is_correct = student_str == correct_str
            score = answer.max_score if is_correct else 0
            feedback = "Exact match grading"
        elif self.use_ai_for_essays and correct_answer:
            score, feedback = self.ai_semantic_grade(
                answer.question_text,
                answer.student_response,
                correct_answer,
                answer.max_score,
                course=course
            )
            is_correct = score >= (answer.max_score * self.semantic_threshold)
        else:
            similarity = self.simple_similarity(
                answer.student_response,
                correct_answer
            )
            score = similarity * answer.max_score
            is_correct = similarity >= self.semantic_threshold
            feedback = f"Similarity score {similarity:.2f}"

        return GradedAnswer(
            question_no=answer.question_no,
            question_type=answer.question_type,
            question_text=answer.question_text,
            student_response=answer.student_response,
            correct_answer=correct_answer,
            score=score,
            max_score=answer.max_score,
            feedback=feedback,
            is_correct=is_correct
        )

    def simple_similarity(self, student: str, correct: str) -> float:
        """Return the Jaccard similarity of the two answers' word sets.

        Words are lower-cased and split on whitespace. Returns ``0.0`` when
        either answer is empty or contains no words (e.g. only whitespace).
        """
        if not student or not correct:
            return 0.0
        student_words = set(str(student).lower().split())
        correct_words = set(str(correct).lower().split())
        union = student_words | correct_words
        if not union:
            # Both inputs were whitespace-only; avoid dividing by zero.
            return 0.0
        return len(student_words & correct_words) / len(union)

    def retrieve_context(self, question: str, course: str):
        """Retrieve reference text for *question* from the vector store.

        Args:
            question: question text to embed and search for.
            course: when truthy, results are restricted to this course.

        Returns:
            The contents of up to 5 matching chunks joined by newlines, or an
            empty string if retrieval fails for any reason.
        """
        try:
            controller = IndexingController()
            embedding = controller.embedder.embed_text(question)

            # Build the metadata filter for the course
            filters = []
            if course:
                filters.append({
                    "field": "course",
                    "op": "eq",
                    "value": course,
                    "clause": "must"
                })

            # Query Qdrant with filters
            results = controller.vector_store.query_qdrant(embedding=embedding, filters=filters, top_k=5)

            context = "\n".join(r["content"] for r in results if r.get("content"))

            logger.info(f"Retrieved {len(results)} chunks for question (filtered by course={course})")
            return context

        except Exception as e:
            # Best effort: grading continues without reference material.
            logger.error(f"Context retrieval failed: {e}")
            return ""

    def build_prompt(self, question, student_answer, correct_answer, context):
        """Return the grading prompt asking the LLM for a JSON score in [0, 1]."""
        return f"""
You are an academic exam grader.

Question:
{question}

Correct Answer:
{correct_answer}

Reference Material:
{context}

Student Answer:
{student_answer}

Evaluate the student answer using semantic similarity.
You may slightly use your knowledge if correct answer not in Reference Material.

Return JSON only:

{{
"score": number between 0 and 1,
"feedback": short explanation
}}
"""

    def parse_llm_output(self, text: str):
        """Extract ``(score, feedback)`` from an LLM reply.

        Accepts a string, a dict (its ``"response"`` entry is used when
        present), or an object with a ``content`` or ``text`` attribute.
        Markdown code fences are removed; if the remainder is not valid JSON
        the first ``{...}`` span is parsed instead. The score is clamped to
        [0, 1].

        Returns:
            ``(score, feedback)``; ``(0, <message>)`` when the reply is empty
            or cannot be parsed.
        """
        try:
            if isinstance(text, dict):
                if 'response' in text:
                    text = text['response']
                else:
                    text = str(text)
            elif hasattr(text, 'content'):
                text = text.content
            elif hasattr(text, 'text'):
                text = text.text
            text = str(text).strip()
            if not text:
                return 0, "Empty response from LLM"
            text = re.sub(r'```json\s*|\s*```', '', text)
            try:
                data = json.loads(text)
            except json.JSONDecodeError:
                json_match = re.search(r'\{.*\}', text, re.DOTALL)
                if json_match:
                    data = json.loads(json_match.group())
                else:
                    raise

            score = float(data.get("score", 0))
            feedback = data.get("feedback", "")
            score = max(0, min(score, 1))
            return score, feedback

        except Exception as e:
            logger.error(f"Failed to parse LLM output: {e}, text type: {type(text)}")
            return 0, "Failed to parse AI grading"

    def ai_semantic_grade(self, question, student, correct, max_score, course):
        """Grade a free-text answer with the LLM, using retrieved context.

        Args:
            question: the question text.
            student: the student's answer.
            correct: the correct answer.
            max_score: maximum score for this question.
            course: course used to filter the retrieved context.

        Returns:
            ``(score, feedback)`` where score is the LLM's 0-1 ratio times
            ``max_score``. If the LLM call fails, the word-overlap similarity
            is used instead.
        """
        try:
            # Retrieve context filtered by course
            context = self.retrieve_context(question, course)

            prompt = self.build_prompt(question, student, correct, context)

            response = self.llm.generate_text(prompt)

            # Log response type for debugging
            logger.info(f"Response type: {type(response)}")

            score_ratio, feedback = self.parse_llm_output(response)
            score = score_ratio * max_score

            return score, feedback

        except Exception as e:
            logger.error(f"AI grading failed: {e}")
            # Fallback to simple similarity
            similarity = self.simple_similarity(student, correct)
            return similarity * max_score, f"Fallback similarity grading: {similarity:.2f}"
255
+
256
def _send_grade_webhook(submission: ExamSubmission, result_dict: Dict[str, Any]) -> None:
    """POST the grading outcome to ``GRADE_WEBHOOK_URL`` (best effort).

    The payload carries a grade summary plus the full result under
    ``"result"``. Any failure is logged and swallowed so that a webhook
    problem never fails the grading task itself.
    """
    try:
        webhook_url = get_settings().GRADE_WEBHOOK_URL
        if not webhook_url:
            logger.warning("GRADE_WEBHOOK_URL is empty or not set!")
            return

        payload = {
            "status": "completed",
            "exam_id": submission.exam_id,
            "student_id": submission.student_id,
            "course_id": submission.course_id,
            "grade": {
                "total_score": result_dict['total_score'],
                "max_total_score": result_dict['max_total_score'],
                "percentage": result_dict['percentage'],
                "grade": result_dict['grade'],
                "graded_time": result_dict['graded_time']
            },
            "result": result_dict,
        }

        response = httpx.post(webhook_url, json=payload, timeout=30.0)

        # Any 2xx status counts as delivered, not only 200.
        if response.is_success:
            logger.info("Grade webhook sent successfully (status %s)", response.status_code)
        else:
            logger.warning(
                "Webhook returned status: %s, response: %s",
                response.status_code,
                response.text[:200],
            )

    except Exception:
        logger.exception("Webhook error")


@shared_task
def grade_exam_task(submission_dict: Dict[str, Any]):
    """Celery task: grade one exam submission and notify the webhook.

    Args:
        submission_dict: keyword data used to build an ``ExamSubmission``.

    Returns:
        The graded result as a dict (``ExamResult.model_dump()``).

    Raises:
        Exception: any validation or grading error is logged and re-raised.
            Webhook failures are logged only.
    """
    try:
        submission = ExamSubmission(**submission_dict)
        service = ExamGradingService()
        result = service.grade_submission(submission)
        result_dict = result.model_dump()
    except Exception:
        logger.exception("ERROR in task")
        raise

    _send_grade_webhook(submission, result_dict)

    logger.info("Task completed successfully")
    return result_dict
generation/ExamRagGenerator.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict
2
+ import logging
3
+ import json
4
+ import re
5
+ import math
6
+ from json_repair import repair_json
7
+ from pydantic import parse_obj_as
8
+ from collections import defaultdict
9
+ from config import get_settings
10
+ from routes.schemas.Exam_Models import *
11
+ from stores.llm.LLMProviderFactory import LLMProviderFactory
12
+ from generation.AssistantRagGenerator import ProviderLLMWrapper
13
+ from generation.prompts import ExamPromptBuilder
14
+ from indexing.indexingController import IndexingController
15
+
16
class ExamService:
    """Generate exams from indexed course material with an LLM.

    The service retrieves chunks for the requested topics, asks the LLM for
    questions in batches, validates them, has the LLM score the assembled
    exam, and retries with the evaluator's feedback until the score reaches
    ``PASS_THRESHOLD`` or ``MAX_GENERATION_ATTEMPTS`` is exhausted.
    """

    MAX_CHUNK_CHARS = 2000
    MAX_TOTAL_CONTEXT = 8000
    MAX_SCORE = 40
    PASS_THRESHOLD = int(MAX_SCORE * 0.8)
    MAX_GENERATION_ATTEMPTS = 3
    # Upper bound on LLM batch calls per attempt, as a multiple of the ideal
    # number of batches. Prevents an endless loop when the model keeps
    # returning questions of the wrong type.
    MAX_BATCH_ROUNDS_FACTOR = 3

    def __init__(self):
        self.logger = logging.getLogger(__name__)
        self._models_initialized = False
        self.settings = get_settings()
        self._init_models()
        self.prompts = ExamPromptBuilder()
        self.controller = IndexingController()
        self.store = self.controller.vector_store
        self.BATCH_SIZE = 10

    def _init_models(self):
        """Create the generation and embedding providers once per instance."""
        if self._models_initialized:
            return
        factory = LLMProviderFactory(self.settings)
        self.generator = factory.create(self.settings.GENERATION_BACKEND)
        self.generator.set_generation_model(self.settings.GENERATION_MODEL_ID)
        self.embedding_provider = factory.create(self.settings.EMBEDDING_BACKEND)
        self.embedding_provider.set_embedding_model(
            self.settings.EMBEDDING_MODEL_ID,
            self.settings.EMBEDDING_MODEL_SIZE,
        )
        self.llm = ProviderLLMWrapper(provider=self.generator)
        self._models_initialized = True

    def _extract_json(self, text: str) -> dict:
        """Extract the first JSON object from LLM output.

        The object starting at the first ``{`` is decoded directly, which
        tolerates prose (even prose containing braces) after it. If that
        fails, the span from the first ``{`` to the last ``}`` is passed to
        ``repair_json`` and decoded again.

        Raises:
            ValueError: If the text holds no ``{`` or the decoded value is
                not a JSON object.
            Exception: Whatever the repair/decoding step raises when the
                text cannot be repaired.
        """
        start = text.find("{")
        if start == -1:
            self.logger.error("No JSON found in LLM response:\n%s", text)
            raise ValueError("LLM returned no JSON")

        try:
            parsed, _ = json.JSONDecoder().raw_decode(text, start)
        except json.JSONDecodeError:
            self.logger.warning("Invalid JSON extracted, attempting repair...")
            end = text.rfind("}")
            # Truncated output may lack the closing brace; let repair add it.
            json_str = text[start:end + 1] if end > start else text[start:]
            try:
                parsed = json.loads(repair_json(json_str))
            except Exception as e:
                self.logger.error("Failed to repair JSON:\n%s\nError: %s", json_str, e)
                raise

        if not isinstance(parsed, dict):
            self.logger.error("LLM JSON is not an object:\n%s", text)
            raise ValueError("LLM returned JSON that is not an object")
        return parsed

    @staticmethod
    def _normalize_mcq(q: dict) -> None:
        """Coerce an MCQ's options to ``list[str]`` and align the answer."""
        options = q.get("options")
        if isinstance(options, dict):
            options = list(options.values())
        elif isinstance(options, str):
            # "A) foo\nB) bar" -> ["foo", "bar"]
            parts = re.split(r"[A-D]\)|\n|\r", options)
            options = [p.strip(" .-") for p in parts if p.strip(" .-")]
        if isinstance(options, list):
            options = [str(o).strip() for o in options]
        else:
            options = []
        q["options"] = options

        correct = q.get("correct_answer")
        if correct is not None:
            correct = str(correct).strip()
            q["correct_answer"] = correct
            # Make sure the answer key is selectable; never add a blank option.
            if correct and correct not in options:
                options.append(correct)
        q.setdefault("explanation", "")

    @staticmethod
    def _normalize_true_false(q: dict) -> None:
        """Map textual true/false answers onto real booleans."""
        answer = q.get("correct_answer")
        if isinstance(answer, str):
            token = answer.strip().lower()
            if token in ("true", "t", "1", "yes"):
                answer = True
            elif token in ("false", "f", "0", "no"):
                answer = False
            else:
                answer = token
            q["correct_answer"] = answer
        q.setdefault("explanation", "")

    def normalize_exam_dict(self, data: dict):
        """Normalize a raw exam dict from the LLM in place and return it.

        Lower-cases the difficulty (dropping any ``Enum.`` prefix), removes
        ``id`` / ``question_id`` / ``points`` from questions, canonicalizes
        question types and fills per-type defaults. Entries of ``questions``
        that are not dicts are dropped.
        """
        difficulty = data.get("difficulty")
        if isinstance(difficulty, str):
            # "Difficulty.HARD" -> "hard"
            data["difficulty"] = difficulty.split(".")[-1].lower()

        questions = data.get("questions")
        if not isinstance(questions, list):
            return data

        normalized_questions = []
        for q in questions:
            if not isinstance(q, dict):
                continue
            for key in ("id", "question_id", "points"):
                q.pop(key, None)

            q_type = q.get("type")
            if isinstance(q_type, str):
                q_type = q_type.lower().strip()
                if q_type == "truefalse":
                    q_type = "true_false"
                q["type"] = q_type

            if "question" in q:
                q["question"] = str(q["question"]).strip()

            if q_type == "mcq":
                self._normalize_mcq(q)
            elif q_type == "true_false":
                self._normalize_true_false(q)
            elif q_type == "short_answer":
                if "answer" in q:
                    q["answer"] = str(q["answer"]).strip()
                q.setdefault("explanation", "")
            elif q_type == "essay":
                if "expected_keywords" in q:
                    keywords = q.pop("expected_keywords")
                    if isinstance(keywords, list):
                        q["answer_guidelines"] = ", ".join(map(str, keywords))
                    else:
                        q["answer_guidelines"] = str(keywords)
                q.setdefault("answer_guidelines", "")
            elif q_type == "code":
                if "solution" in q:
                    q["solution"] = str(q["solution"])
                q.setdefault("starter_code", None)
                q.setdefault("explanation", "")
            normalized_questions.append(q)

        data["questions"] = normalized_questions
        return data

    def generate_exam(self, request: "ExamGenerationRequest", context: str, llm, batch_size: int) -> "List[QuestionUnion]":
        """Ask the LLM for one batch of questions and validate them.

        Questions that cannot be validated are logged and skipped instead of
        failing the whole batch; the caller tops up missing questions in a
        later batch.

        Raises:
            RuntimeError: If the LLM returns nothing.
            ValueError: If no JSON object can be extracted from the reply
                (``json.JSONDecodeError`` is a ``ValueError``).
        """
        batch_request = request.model_copy()
        batch_request.total_questions = batch_size

        prompt = self.prompts.build_exam_generation_prompt(batch_request, context)
        raw_text = llm._call(prompt)
        if not raw_text:
            raise RuntimeError("LLM generation failed")

        # Drop markdown code fences the model may have added.
        cleaned = re.sub(r"```[a-zA-Z]*|```", "", raw_text).strip()

        try:
            exam_dict = self.normalize_exam_dict(self._extract_json(cleaned))
        except json.JSONDecodeError:
            self.logger.error("Invalid JSON from LLM:\n%s", raw_text)
            raise

        received = exam_dict.get("questions")
        if not isinstance(received, list):
            received = []

        questions = []
        for q in received[:batch_size]:
            if not isinstance(q, dict):
                continue
            if q.get("type") == "mcq":
                if not q.get("options"):
                    self.logger.warning(f"Skipping MCQ with no options: {q}")
                    continue
                if not q.get("correct_answer"):
                    # NOTE(review): this placeholder makes the first option
                    # the answer key; consider dropping the question instead.
                    self.logger.warning("MCQ without correct_answer, using first option: %s", q)
                    q["correct_answer"] = q["options"][0]
            try:
                questions.append(parse_obj_as(QuestionUnion, q))
            except ValueError as exc:  # pydantic's ValidationError is a ValueError
                self.logger.warning("Skipping invalid question (%s): %s", exc, q)

        self.logger.info(
            "Batch requested=%d | received=%d | kept=%d",
            batch_size,
            len(received),
            len(questions),
        )
        return questions

    def evaluate_exam(self, request: "ExamGenerationRequest", exam: "ExamResponse", llm):
        """Have the LLM score *exam* and return the parsed ``EvaluationResult``.

        Raises:
            RuntimeError: If the LLM returns nothing.
            ValueError: If no JSON object can be extracted from the reply.
        """
        prompt = self.prompts.build_exam_evaluation_prompt(request, exam)
        raw_text = llm._call(prompt)
        if not raw_text:
            raise RuntimeError("Evaluation generation failed")

        cleaned = re.sub(r"```[a-zA-Z]*|```", "", raw_text).strip()
        try:
            evaluation_dict = self._extract_json(cleaned)
        except json.JSONDecodeError:
            self.logger.error("Invalid evaluation JSON:\n%s", raw_text)
            raise

        return EvaluationResult.model_validate(evaluation_dict)

    @staticmethod
    def _chunk_meta(chunk: dict) -> dict:
        """Return the chunk's metadata dict, wherever the store put it."""
        meta = chunk.get("metadata")
        if isinstance(meta, dict):
            return meta
        payload = chunk.get("payload")
        if isinstance(payload, dict) and isinstance(payload.get("metadata"), dict):
            return payload["metadata"]
        return {}

    @staticmethod
    def _chunk_text(chunk: dict) -> str:
        """Return the chunk's text, trying the known key layouts in order."""
        payload = chunk.get("payload")
        candidates = []
        if isinstance(payload, dict):
            candidates += [payload.get("text"), payload.get("content")]
        candidates += [chunk.get("text"), chunk.get("content"), chunk.get("page_content")]
        for value in candidates:
            if isinstance(value, str) and value:
                return value
        return ""

    def split_chunks_by_topic_batches(self, exam_chunks, num_batches):
        """Distribute the chunks of every topic round-robin over the batches.

        The round-robin position carries over from one topic to the next, so
        topics with few chunks do not all pile up in the first batch while
        later batches stay empty. At least one batch is always returned.
        """
        num_batches = max(1, num_batches)
        self.logger.info(f"Topics retrieved: {list(exam_chunks.keys())}")
        self.logger.info(f"Number of batches: {num_batches}")

        batches = [[] for _ in range(num_batches)]
        position = 0
        for topic, chunks in exam_chunks.items():
            self.logger.info(f"Topic '{topic}' -> {len(chunks)} chunks distributed across batches")
            for chunk in chunks:
                batches[position % num_batches].append(chunk)
                position += 1

        for i, batch in enumerate(batches):
            topic_counter = defaultdict(int)
            for chunk in batch:
                topic_counter[self._chunk_meta(chunk).get("topic", "unknown")] += 1
            self.logger.info(f"Batch {i+1} contains {len(batch)} chunks -> {dict(topic_counter)}")

        return batches

    @staticmethod
    def _next_batch_distribution(remaining_distribution: dict, batch_size: int) -> dict:
        """Pick up to *batch_size* still-missing questions, type by type."""
        batch_distribution = {}
        slots_left = batch_size
        for qtype, count in remaining_distribution.items():
            if slots_left <= 0:
                break
            take = min(count, slots_left)
            if take > 0:
                batch_distribution[qtype] = take
                slots_left -= take
        return batch_distribution

    def exam_task(self, request_dict: dict) -> "ExamResponse":
        """Generate a full exam and return the best-scoring attempt.

        Each attempt builds the exam batch by batch (every batch sees one
        slice of the retrieved chunks), then has the LLM evaluate it. The
        evaluator's feedback is fed into the next attempt.

        Raises:
            RuntimeError: If no attempt produced an exam.
        """
        request = ExamGenerationRequest.model_validate(request_dict)

        topics_with_embeddings = self.prepare_topics_with_embeddings(request.topics)
        exam_chunks = self.store.retrieve_for_exam(
            topics_with_embeddings, request.username, request.course, request.references
        )

        num_batches = max(1, math.ceil(request.total_questions / self.BATCH_SIZE))
        self.logger.info(f"Raw exam_chunks structure: {type(exam_chunks)}")
        for k, v in exam_chunks.items():
            self.logger.info(f"Topic={k} | type={type(v)} | len={len(v) if hasattr(v,'__len__') else 'NA'}")

        chunk_batches = self.split_chunks_by_topic_batches(exam_chunks, num_batches)
        max_rounds = num_batches * self.MAX_BATCH_ROUNDS_FACTOR

        feedback_context = ""
        best_exam = None
        best_score = 0

        for attempt in range(self.MAX_GENERATION_ATTEMPTS):
            self.logger.info(f"Generating exam attempt {attempt+1}")

            remaining_distribution = dict(request.question_types_distribution)
            all_questions = []
            batch_index = 0

            # Bounded loop: a model that never returns the missing question
            # types must not keep the task calling the LLM forever.
            while len(all_questions) < request.total_questions and batch_index < max_rounds:
                remaining = request.total_questions - len(all_questions)
                batch_distribution = self._next_batch_distribution(
                    remaining_distribution, min(self.BATCH_SIZE, remaining)
                )
                if not batch_distribution:
                    break

                batch_request = request.model_copy()
                batch_request.total_questions = sum(batch_distribution.values())
                batch_request.question_types_distribution = batch_distribution

                chunk_subset = chunk_batches[batch_index % len(chunk_batches)]
                self.logger.info(f"\n===== BATCH {batch_index+1} CHUNKS =====")
                for i, chunk in enumerate(chunk_subset):
                    meta = self._chunk_meta(chunk)
                    topic = meta.get("topic", "unknown")
                    page = meta.get("page", "NA")
                    preview = self._chunk_text(chunk)[:200].replace("\n", " ")
                    self.logger.info(
                        f"Chunk {i+1} | Topic={topic} | Page={page} | Preview={preview}"
                    )
                self.logger.info("=====================================\n")
                batch_index += 1

                batch_context = self.build_exam_context(chunk_subset)
                if feedback_context:
                    batch_context += f"\n\nEvaluator Feedback:\n{feedback_context}"

                batch_questions = self.generate_exam(
                    batch_request, batch_context, self.llm, batch_request.total_questions
                )

                # Keep only questions of a type that is still needed.
                for q in batch_questions:
                    if remaining_distribution.get(q.type, 0) > 0:
                        all_questions.append(q)
                        remaining_distribution[q.type] -= 1
                    if len(all_questions) >= request.total_questions:
                        break

            if len(all_questions) < request.total_questions:
                self.logger.warning(
                    "Attempt %d produced %d of %d questions",
                    attempt + 1, len(all_questions), request.total_questions,
                )

            exam_dict = {
                "exam_id": request.exam_id,
                "difficulty": request.difficulty,
                "total_questions": request.total_questions,
                "expected_distribution": request.question_types_distribution,
                "questions": all_questions[:request.total_questions],
            }

            try:
                exam = ExamResponse.model_validate(exam_dict)
            except Exception as e:
                self.logger.error(f"Exam validation failed: {e}")
                raise

            evaluation = self.evaluate_exam(request, exam, self.llm)
            self.logger.info(f"Evaluation score: {evaluation.overall_score}")

            # The first exam always counts, even with a score of 0.
            if best_exam is None or evaluation.overall_score > best_score:
                best_score = evaluation.overall_score
                best_exam = exam

            if evaluation.overall_score >= self.PASS_THRESHOLD:
                break

            feedback_context = evaluation.feedback

        if best_exam is None:
            raise RuntimeError("Exam generation failed after retries")

        return best_exam

    def build_exam_context(self, exam_chunks) -> str:
        """Render retrieved chunks as the prompt context.

        Accepts either ``{topic: [chunks]}`` or a flat ``[chunks]`` list
        (grouped by each chunk's ``topic`` metadata). Each chunk is cut to
        ``MAX_CHUNK_CHARS`` and the whole context to ``MAX_TOTAL_CONTEXT``
        characters. Chunks without text are skipped.
        """
        if isinstance(exam_chunks, list):
            topic_chunks = defaultdict(list)
            for c in exam_chunks:
                topic_chunks[self._chunk_meta(c).get("topic", "Unknown")].append(c)
            exam_chunks = topic_chunks

        context_parts = []
        total_length = 0

        for topic, chunks in exam_chunks.items():
            topic_header = f"\n### Topic: {topic}\n"
            if total_length + len(topic_header) > self.MAX_TOTAL_CONTEXT:
                break
            context_parts.append(topic_header)
            total_length += len(topic_header)

            for c in chunks:
                text = self._chunk_text(c)
                if not text:
                    continue
                meta = self._chunk_meta(c)
                source = meta.get("source", "")
                bookmark = meta.get("bookmark_path", "")

                text = text[:self.MAX_CHUNK_CHARS]
                formatted_chunk = f"[Source: {source} | Bookmark: {bookmark}]\n{text}\n"
                if total_length + len(formatted_chunk) > self.MAX_TOTAL_CONTEXT:
                    break
                context_parts.append(formatted_chunk)
                total_length += len(formatted_chunk)

        return "\n".join(context_parts)

    def prepare_topics_with_embeddings(self, topics: List[str]):
        """Embed every topic and return ``(topic, embedding)`` pairs.

        Topics whose embedding fails or comes back as ``None`` are logged
        and left out.
        """
        results = []
        for topic in topics:
            try:
                embedding = self.embedding_provider.embed_text(topic)
            except Exception as e:
                self.logger.warning(f"Embedding failed for topic '{topic}': {e}")
                continue
            if embedding is None:
                self.logger.warning("Embedding provider returned nothing for topic '%s'", topic)
                continue
            results.append((topic, embedding))
        self.logger.info(f"Prepared {len(results)} topic embeddings")
        return results
generation/__init__.py ADDED
File without changes
generation/answer_models.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import List, Optional, Dict, Any, Union
3
+ from enum import Enum
4
+
5
class QuestionType(str, Enum):
    """Kinds of question a student answer can belong to (string-valued)."""

    # NOTE(review): the generation prompt uses "mcq" for multiple-choice
    # questions; confirm which spelling the grading side receives.
    MULTIPLE_CHOICE = "multiple_choice"
    TRUE_FALSE = "true_false"
    SHORT_ANSWER = "short_answer"
    ESSAY = "essay"
    CODE = "code"
11
+
12
class StudentAnswer(BaseModel):
    """A single answer given by a student, as submitted for grading."""

    question_no: int
    question_type: QuestionType
    question_text: str
    student_response: str
    # Points available for this question; defaults to 1.0 when omitted.
    max_score: float = 1.0
    # Free-form extra data; defaults to an empty dict.
    metadata: Optional[Dict[str, Any]] = {}
19
+
20
class GradedAnswer(BaseModel):
    """The grading outcome for one student answer."""

    question_no: int
    question_type: QuestionType
    question_text: str
    student_response: str
    # May be None, but has no default, so it must always be passed.
    correct_answer: Optional[Any]
    score: float
    max_score: float
    feedback: str
    is_correct: bool
30
+
31
class ExamSubmission(BaseModel):
    """A student's complete set of answers for one exam."""

    exam_id: str
    course_id: str
    student_id: str
    # Truly optional: under Pydantic v2 an Optional field without a default
    # is still required, which rejected submissions that omit the name.
    student_name: Optional[str] = None
    answers: List[StudentAnswer]
    submission_time: str
    # Free-form extra data; defaults to an empty dict.
    metadata: Optional[Dict[str, Any]] = {}
39
+
40
class ExamResult(BaseModel):
    """The graded result of one exam submission."""

    exam_id: str
    student_id: str
    # The three Optional fields below may be None but have no default, so
    # they must always be passed explicitly.
    student_name: Optional[str]
    graded_answers: List[GradedAnswer]
    total_score: float
    max_total_score: float
    percentage: float
    grade: Optional[str]
    feedback_summary: Optional[str]
    submission_time: str
    graded_time: str
generation/parsing_utils.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ from typing import Any, Dict, Optional
4
+
5
logger = logging.getLogger("ExamGraph")


def safe_parse(parser_obj, text: str, question_no: int) -> Optional[Dict[str, Any]]:
    """Parse LLM output with *parser_obj*, tolerating surrounding prose.

    The text is first handed to ``parser_obj.parse`` unchanged. If that
    fails, JSON candidates are cut out of the text and parsed instead: the
    outermost ``{...}`` span first (so nested objects survive), then the span
    starting at the last ``{`` (the final object when several appear).

    Args:
        parser_obj: Any object with a ``parse(str)`` method.
        text: Raw LLM response.
        question_no: Question number, used only in log messages.

    Returns:
        The parsed result (``model_dump()`` of it when available), or
        ``None`` when the text is empty/null or every attempt fails.
    """
    if not text or text.strip() in ("null", "None", ""):
        logger.warning(f"[Parse] q{question_no}: empty/null response")
        return None

    def _as_dict(result):
        return result.model_dump() if hasattr(result, "model_dump") else result

    last_error = None

    # Try direct parse
    try:
        return _as_dict(parser_obj.parse(text))
    except Exception as e:
        last_error = e
        logger.debug(f"[Parse] q{question_no}: direct parse failed, trying extraction")

    # Try to extract JSON from text (LLM may have wrapped it in prose)
    end = text.rfind("}") + 1
    starts = []
    for start in (text.find("{"), text.rfind("{")):
        if start >= 0 and end > start and start not in starts:
            starts.append(start)

    for start in starts:
        try:
            json_obj = json.loads(text[start:end])
            return _as_dict(parser_obj.parse(json.dumps(json_obj)))
        except Exception as e:
            last_error = e
            logger.debug(f"[Parse] q{question_no}: json extraction failed")

    # Nothing worked: report the last error so the caller can regenerate.
    error_msg = str(last_error) if last_error else "unknown"
    logger.error(f"[Parse] q{question_no}: failed all attempts: {error_msg}")
    return None
40
+
41
def categorize_error(error_str: str) -> str:
    """Map an error message onto a coarse failure category.

    Matching is case-insensitive and the first rule that hits wins, in this
    order: ``timeout``, ``invalid_json``, ``missing_field``, ``null``.
    Messages matching no rule yield ``"unknown"``.
    """
    lowered = error_str.lower()
    # (category, substrings that select it) in priority order.
    rules = (
        ("timeout", ("timeout",)),
        ("invalid_json", ("json", "invalid")),
        ("missing_field", ("field required", "missing")),
        ("null", ("none", "null")),
    )
    for category, needles in rules:
        if any(needle in lowered for needle in needles):
            return category
    return "unknown"
generation/prompts.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from generation.ExamRagGenerator import ExamGenerationRequest, ExamResponse
2
+ import json
3
+
4
class ExamPromptBuilder:
    """Builds the LLM prompts for exam generation and exam evaluation."""

    # Upper end of the scoring range quoted in the evaluation prompt.
    MAX_SCORE = 40

    def build_exam_generation_prompt(self,request: ExamGenerationRequest,context: str) -> str:
        """Return the prompt asking the LLM to generate an exam as JSON.

        The prompt embeds the request's course, difficulty, total question
        count and per-type distribution, the supplied context, one example
        per question type, and the JSON skeleton to fill in.
        """
        # Keys become their enum values so the mapping is JSON-serializable.
        distribution = {
            q_type.value: count
            for q_type, count in request.question_types_distribution.items()
        }

        # NOTE(review): the CODE QUESTIONS example below asks for a Python
        # function and shows Python code but sets "language" to "c" --
        # confirm the intended value.
        return f"""
You are an automated exam generation system.

Your job is to produce a structured exam strictly following the schema below.

----------------------------------------------------
CRITICAL OUTPUT RULES
----------------------------------------------------

You MUST return ONLY a valid JSON object.

Do NOT include:

- explanations
- markdown
- comments
- code blocks
- text before or after the JSON

The response MUST start with {{ and end with }}.

If the output is not valid JSON the result will be rejected.

----------------------------------------------------
ENUM VALUES (STRICT)
----------------------------------------------------

difficulty must be exactly one of:

easy
medium
hard

question type must be exactly one of:

mcq
true_false
short_answer
essay
code

----------------------------------------------------
EXAM REQUIREMENTS
----------------------------------------------------

course: {request.course}

difficulty: {request.difficulty.value}

total_questions: {request.total_questions}

question_types_distribution:
{json.dumps(distribution)}

You MUST generate exactly:

{json.dumps(distribution)}

Example:

{{
"mcq": 3,
"essay": 2
}}

means exactly:
3 mcq questions
2 essay questions

----------------------------------------------------
CONTEXT
----------------------------------------------------

Use ONLY the information from this context when creating questions.

{context}

----------------------------------------------------
QUESTION RULES
----------------------------------------------------

MCQ QUESTIONS

- must contain exactly 4 options
- options must be plain text
- correct_answer must match one option EXACTLY
- do NOT use letters like A/B/C/D
- do NOT include numbering inside options

Example:

{{
"type": "mcq",
"question": "What is 2 + 2?",
"options": ["1","2","3","4"],
"correct_answer": "4",
"explanation": "2 + 2 equals 4"
}}

----------------------------------------------------

TRUE/FALSE QUESTIONS

correct_answer must be boolean.

Example:

{{
"type": "true_false",
"question": "The Earth revolves around the Sun.",
"correct_answer": true,
"explanation": "Astronomy confirms this."
}}

----------------------------------------------------

SHORT ANSWER QUESTIONS

Example:

{{
"type": "short_answer",
"question": "Define photosynthesis.",
"answer": "Process where plants convert light into chemical energy",
"explanation": "Occurs in chloroplasts using sunlight"
}}

----------------------------------------------------

ESSAY QUESTIONS

Example:

{{
"type": "essay",
"question": "Explain Newton's First Law.",
"answer": "Newton's First Law states that an object will remain at rest or continue moving in a straight line at constant velocity unless acted upon by an external force. This property is called inertia. For example, a book on a table stays at rest until someone pushes it, and a moving car continues moving until friction or braking stops it.",
"answer_guidelines": "Describe inertia and provide examples"
}}

----------------------------------------------------

CODE QUESTIONS

Rules:

starter_code must be either a string OR null.
Never output the string "None".

Example:

{{
"type": "code",
"question": "Write a Python function to compute factorial.",
"language": "c",
"starter_code": "def factorial(n):",
"solution": "def factorial(n): return 1 if n<=1 else n*factorial(n-1)",
"explanation": "Uses recursion"
}}

----------------------------------------------------
IMPORTANT RESTRICTIONS
----------------------------------------------------

Do NOT output:

LaTeX
math formulas
markdown
additional fields

Use plain text only.

----------------------------------------------------
FINAL JSON STRUCTURE
----------------------------------------------------

{{
"exam_id": "{request.exam_id}",
"difficulty": "{request.difficulty.value}",
"total_questions": {request.total_questions},
"expected_distribution": {json.dumps(distribution)},
"questions": []
}}

Fill the questions array with the generated questions.

----------------------------------------------------

Return ONLY the JSON object.
"""

    def build_exam_evaluation_prompt(self,request: ExamGenerationRequest,exam: ExamResponse) -> str:
        """Return the prompt asking the LLM to score *exam* as JSON.

        The prompt embeds the exam serialized with ``model_dump_json()`` and
        asks for an ``overall_score`` between 0 and ``MAX_SCORE`` plus a
        short ``feedback`` string. ``request`` is not used in the prompt.
        """

        exam_json = exam.model_dump_json()

        return f"""
You are an exam quality evaluator.

--------------------------------
OUTPUT RULES
--------------------------------
1. Output MUST be valid JSON.
2. Do NOT include markdown.
3. Do NOT include reasoning outside JSON.
4. Output ONLY the JSON object.
5. JSON must start with {{ and end with }}.

--------------------------------
SCORING RANGE
--------------------------------
0 to {self.MAX_SCORE}

--------------------------------
EVALUATION CRITERIA
--------------------------------
1. Relevance of questions to the topics
2. Correct distribution of question types
3. Clarity and wording of questions
4. Difficulty consistency
5. Correctness of answers

--------------------------------
EXAM TO EVALUATE
--------------------------------
{exam_json}

--------------------------------
OUTPUT FORMAT
--------------------------------

{{
"overall_score": integer between 0 and {self.MAX_SCORE},
"feedback": "short explanation of issues if any"
}}

Return ONLY JSON.
"""
indexing/indexingController.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from stores.llm.LLMProviderFactory import LLMProviderFactory
2
+ from stores.vector_store.Qdrant import QdrantStore
3
+
4
+ from ingestion.loaders.File_loader import load_file
5
+ from ingestion.chunkers.recursive_chunker import recursive_chunk
6
+ from ingestion.pdf_outline import extract_pdf_outline, build_page_bookmark_map , recursive_chunk_with_pages
7
+ from ingestion.loaders.pdf_loader import load_pdf_with_pages
8
+
9
+ from config import get_settings
10
+
11
+ import os
12
+ from qdrant_client import QdrantClient , models
13
+
14
class IndexingController:
    """Loads files, chunks and embeds them, and stores them in Qdrant."""

    def __init__(self):
        """Create the embedder, the Qdrant client and the vector store.

        Raises:
            ValueError: If ``QDRANT_TYPE`` is not ``cloud``, ``docker`` or
                ``local``.
        """
        config = get_settings()
        self.factory = LLMProviderFactory(config)
        self.embedder = self.factory.create(config.EMBEDDING_BACKEND)
        self.embedder.set_embedding_model(config.EMBEDDING_MODEL_ID, config.EMBEDDING_MODEL_SIZE)

        if config.QDRANT_TYPE == "cloud":
            self.vector_store_client = QdrantClient(url=config.QDRANT_DOCKER_URL, api_key=config.QDRANT_API_KEY, timeout=120)
        elif config.QDRANT_TYPE == "docker":
            self.vector_store_client = QdrantClient(url=config.QDRANT_DOCKER_URL, timeout=120)
        elif config.QDRANT_TYPE == "local":
            self.vector_store_client = QdrantClient(path="data/qdrant", prefer_grpc=False, timeout=120)
        else:
            # Fail here with a clear message instead of an AttributeError on
            # the first use of the missing client.
            raise ValueError(
                f"Unsupported QDRANT_TYPE {config.QDRANT_TYPE!r}; "
                "expected 'cloud', 'docker' or 'local'"
            )

        string_fields = ["metadata.username", "metadata.source", "metadata.course", "metadata.bookmark_path"]

        if not self.vector_store_client.collection_exists(collection_name=config.QDRANT_COLLECTION):
            # Create the collection if it doesn't exist yet.
            self.vector_store_client.create_collection(
                collection_name=config.QDRANT_COLLECTION,
                vectors_config=models.VectorParams(
                    size=config.EMBEDDING_MODEL_SIZE,
                    distance=models.Distance.COSINE,
                ),
            )

        # Keyword indexes for the metadata fields used as filters.
        # NOTE(review): assumed to run on every start, not only when the
        # collection was just created -- confirm against the original file.
        for field in string_fields:
            self.vector_store_client.create_payload_index(
                collection_name=config.QDRANT_COLLECTION,
                field_name=field,
                field_schema=models.KeywordIndexParams(
                    type=models.KeywordIndexType.KEYWORD
                ),
            )

        self.vector_store = QdrantStore(self.vector_store_client, config.QDRANT_COLLECTION, config.EMBEDDING_MODEL_SIZE)

    def embed_chunks(self, chunks):
        """Embed a list of chunk texts with the configured embedder."""
        return self.embedder.embed_text_batch(chunks)

    def process_file(self, file_path, original_filename, username=None, course=None):
        """Chunk, embed and store one file.

        PDFs are chunked per page and tagged with the bookmark path of their
        page; other files are loaded as one text and chunked recursively.
        Chunks whose embedding is ``None`` are not stored.

        Returns:
            A dict with ``num_chunks``, the ``chunks`` themselves and the raw
            ``embeddings`` list (empty when the file produced no chunks).
        """
        file_name = os.path.basename(file_path)
        ext = os.path.splitext(file_path)[1].lower()

        bookmark_map = {}

        if ext == ".pdf":
            outline, total_pages = extract_pdf_outline(file_path)
            bookmark_map = build_page_bookmark_map(outline, total_pages)

            pages = load_pdf_with_pages(file_path)
            chunks = recursive_chunk_with_pages(pages)
        else:
            text = load_file(file_path)
            if isinstance(text, list):
                text = " ".join([doc.page_content for doc in text])
            chunks_text = recursive_chunk(text)
            chunks = [{"text": c, "page": None} for c in chunks_text]

        if not chunks:
            # Empty or unsupported file: nothing to embed or store.
            print(f"[INFO] Stored 0 embeddings for file '{file_name}'.")
            return {"num_chunks": 0, "chunks": chunks, "embeddings": []}

        embeddings = self.embed_chunks([c["text"] for c in chunks]) or []

        valid_embs = []
        valid_payloads = []

        for idx, (chunk_obj, emb) in enumerate(zip(chunks, embeddings)):
            if emb is None:
                continue
            page = chunk_obj["page"]
            bookmark_path = bookmark_map.get(page, [])

            valid_embs.append(emb)
            valid_payloads.append({
                "content": chunk_obj["text"],
                "metadata": {
                    "source": original_filename,
                    "chunk_index": idx,
                    "total_chunks": len(chunks),
                    "username": username,
                    "course": course,
                    "page": page,
                    "bookmark_path": bookmark_path,
                },
            })
            print(f"[DEBUG] Prepared payload for chunk {idx}: page={page}, bookmark_path={bookmark_path}")

        if valid_embs:
            self.vector_store.upsert_embeddings(
                self.vector_store_client,
                get_settings().QDRANT_COLLECTION,
                valid_embs,
                valid_payloads,
            )
        print(f"[INFO] Stored {len(valid_embs)} embeddings for file '{file_name}'.")

        return {
            "num_chunks": len(chunks),
            "chunks": chunks,
            "embeddings": embeddings,
        }
ingestion/chunkers/__init__.py ADDED
File without changes
ingestion/chunkers/fixed_chunker.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_text_splitters import CharacterTextSplitter
2
+ from config import get_settings
3
+
4
def fixed_chunk(text):
    """Split *text* with ``CharacterTextSplitter`` using the configured
    ``CHUNK_SIZE`` and ``CHUNK_OVERLAP``, and return the list of chunks."""
    splitter_options = {
        "chunk_size": get_settings().CHUNK_SIZE,
        "chunk_overlap": get_settings().CHUNK_OVERLAP,
    }
    return CharacterTextSplitter(**splitter_options).split_text(text)
ingestion/chunkers/recursive_chunker.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
2
+ from config import get_settings
3
+
4
def recursive_chunk(text):
    """Split ``text`` into chunks with a ``RecursiveCharacterTextSplitter``.

    Chunk size and overlap come from the ``CHUNK_SIZE`` and ``CHUNK_OVERLAP``
    settings. Returns whatever ``split_text`` produces for ``text``.
    """
    size = get_settings().CHUNK_SIZE
    overlap = get_settings().CHUNK_OVERLAP
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=size,
        chunk_overlap=overlap,
    )
    return text_splitter.split_text(text)
ingestion/loaders/File_loader.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from config import get_settings
2
+ import os
3
+
4
def get_file_extension(file_id: str):
    """Return the extension of ``file_id`` with its leading dot.

    The case of the extension is preserved; a name without an extension
    yields an empty string.
    """
    _stem, extension = os.path.splitext(file_id)
    return extension
6
+
7
+
8
def load_file(file_path: str):
    """Load a document from ``file_path`` using the configured loader family.

    The file type is chosen from the lower-cased extension, so ``.PDF`` and
    ``.pdf`` are treated alike in both loader families.

    When the ``CustomLoaders`` setting is truthy the project's own loaders
    are used; an unsupported extension is reported with ``print`` and an
    empty list is returned. Otherwise the LangChain community loaders are
    used and an unsupported extension raises ``ValueError``.

    Returns:
        The list produced by the selected loader.

    Raises:
        ValueError: unsupported extension when custom loaders are disabled.
    """
    ext = os.path.splitext(file_path)[1].lower()

    if get_settings().CustomLoaders:
        # Imported lazily so the optional parsing libraries are only needed
        # when the custom loaders are actually enabled.
        from ingestion.loaders.pdf_loader import load_pdf
        from ingestion.loaders.txt_loader import load_txt
        from ingestion.loaders.md_loader import load_md
        from ingestion.loaders.docx_loader import load_docx

        custom_loaders = {
            ".pdf": load_pdf,
            ".docx": load_docx,
            ".md": load_md,
            ".txt": load_txt,
        }
        loader = custom_loaders.get(ext)
        if loader is None:
            print(f"Unsupported file type: {ext}")
            return []

        # Return list of Document objects as-is
        return loader(file_path)

    from langchain_community.document_loaders import (
        TextLoader,
        Docx2txtLoader,
        UnstructuredMarkdownLoader,
        PyMuPDFLoader,
    )

    if ext == ".txt":
        return TextLoader(file_path, encoding="utf8").load()
    if ext == ".docx":
        return Docx2txtLoader(file_path).load()
    if ext == ".md":
        return UnstructuredMarkdownLoader(file_path).load()
    if ext == ".pdf":
        return PyMuPDFLoader(file_path).load()
    raise ValueError(f"Unsupported file extension: {ext}")
ingestion/loaders/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
ingestion/loaders/docx_loader.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+ from langchain_core.documents import Document
4
+ from docx import Document as DocxDocument
5
+ from docx.oxml.table import CT_Tbl
6
+ from docx.oxml.text.paragraph import CT_P
7
+ from ingestion.loaders.normalization import normalize_text
8
+
9
def table_to_text(table) -> str:
    """Render a DOCX table as plain text, one line per non-empty row.

    Cell texts are passed through ``normalize_text`` and joined with
    ``" | "``; rows whose cells are all empty are dropped. Any error is
    printed and an empty string is returned.
    """
    try:
        lines = []
        for row in table.rows:
            cells = [normalize_text(cell.text) for cell in row.cells]
            if any(cells):  # skip empty rows
                lines.append(" | ".join(cells))
        # An empty table (or one with only blank rows) joins to "".
        return "\n".join(lines)
    except Exception as e:
        print(f"Error converting table to text: {e}")
        return ""
27
+
28
+
29
+
30
+
31
def load_docx(file_path: str) -> List[Document]:
    """Read a DOCX file into ``Document`` objects, following body order.

    Non-empty paragraphs become documents with ``type="text"``; tables are
    flattened with ``table_to_text`` and become ``type="table"`` documents.
    A missing or unreadable file, or a failure while setting up the
    traversal, is printed and yields an empty list. An error on a single
    paragraph or table is printed and that element is skipped.
    """
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return []

    try:
        docx_file = DocxDocument(file_path)
    except Exception as e:
        print(f"Failed to open DOCX ({file_path}): {e}")
        return []

    collected = []
    try:
        nodes = list(docx_file.element.body)
        # Paragraphs and tables are consumed in step with the body elements
        # so that the output keeps the document's original ordering.
        paragraphs = iter(docx_file.paragraphs)
        tables = iter(docx_file.tables)

        for node in nodes:
            if isinstance(node, CT_P):
                try:
                    body_text = normalize_text(next(paragraphs).text)
                    if body_text:
                        collected.append(
                            Document(
                                page_content=body_text,
                                metadata={"source": file_path, "type": "text"},
                            )
                        )
                except StopIteration:
                    continue
                except Exception as e:
                    print(f"Error reading paragraph: {e}")
                    continue
            elif isinstance(node, CT_Tbl):
                try:
                    flattened = table_to_text(next(tables))
                    if flattened:
                        collected.append(
                            Document(
                                page_content=flattened,
                                metadata={"source": file_path, "type": "table"},
                            )
                        )
                except StopIteration:
                    continue
                except Exception as e:
                    print(f"Error reading table: {e}")
                    continue

    except Exception as e:
        print(f"[WARN] Error processing DOCX ({file_path}): {e}")
        return []

    return collected
ingestion/loaders/md_loader.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ from typing import List
4
+ from langchain_core.documents import Document
5
+ from ingestion.loaders.normalization import normalize_text
6
+
7
+
8
def load_md(file_path: str) -> List[Document]:
    """Read a Markdown file into ``Document`` objects split at pipe tables.

    The file is decoded as UTF-8, falling back to Latin-1. The text is split
    into alternating prose and table segments; each non-blank segment is
    stripped of code spans, ``**``, ``__`` and ``#`` markers, normalized, and
    stored with ``type`` set to ``"table"`` or ``"text"``. Any failure is
    printed and an empty list is returned.
    """
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return []

    try:
        with open(file_path, "r", encoding="utf-8") as handle:
            raw = handle.read()
    except UnicodeDecodeError:
        try:
            with open(file_path, "r", encoding="latin-1") as handle:
                raw = handle.read()
        except Exception as e:
            print(f"Failed to read Markdown file ({file_path}): {e}")
            return []
    except Exception as e:
        print(f"Could not open Markdown file ({file_path}): {e}")
        return []

    collected = []
    try:
        # The capturing group keeps the table runs in the split result.
        for segment in re.split(r"((?:\|.*\|\n)+)", raw):
            if segment.strip():
                is_table = re.match(r"(?:\|.*\|\n)+", segment)
                kind = "table" if is_table else "text"

                # Drop markdown decoration before normalizing whitespace.
                stripped = re.sub(r'(```.*?```|`.*?`|\*\*|__|#)', '', segment, flags=re.DOTALL)
                body = normalize_text(stripped)
                if body:
                    collected.append(
                        Document(page_content=body, metadata={"source": file_path, "type": kind})
                    )
    except Exception as e:
        print(f"Error parsing Markdown file ({file_path}): {e}")
        return []

    return collected
ingestion/loaders/normalization.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
def normalize_text(text: str) -> str:
    """Return ``text`` cleaned up for indexing.

    Removes ``(cid:NNNN)`` artifacts, strips characters in the emoji/symbol
    ranges listed below, turns every run of whitespace (including newlines
    and tabs) into a single space, and trims the ends. Falsy input yields
    an empty string.
    """
    if not text:
        return ""

    # Drop PDF CID artifacts such as (cid:1234).
    cleaned = re.sub(r'\(cid:\d+\)', '', text)

    # Newlines and tabs become plain spaces.
    cleaned = cleaned.translate(str.maketrans("\n\t", "  "))

    # Remove emojis and pictographs
    symbols = re.compile(
        "["
        "\U0001F600-\U0001F64F"  # emoticons
        "\U0001F300-\U0001F5FF"  # symbols & pictographs
        "\U0001F680-\U0001F6FF"  # transport & map
        "\U0001F1E0-\U0001F1FF"  # flags
        "\U00002500-\U00002BEF"
        "\U00002700-\U000027BF"
        "\U0001F900-\U0001F9FF"
        "\U0001FA70-\U0001FAFF"
        "\U00002600-\U000026FF"
        "\U00002B00-\U00002BFF"
        "]+", flags=re.UNICODE
    )
    cleaned = symbols.sub("", cleaned)

    # Collapse whitespace runs left behind by the removals above.
    return re.sub(r'\s+', ' ', cleaned).strip()
ingestion/loaders/pdf_loader.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langchain_core.documents import Document
3
+ import pdfplumber
4
+ from ingestion.loaders.normalization import normalize_text
5
+
6
def load_pdf(file_path: str):
    """Read a PDF with pdfplumber into one ``Document`` per page.

    Each page's normalized text is followed by every extracted table,
    rendered under a ``=== Table i (Page n) ===`` header. Pages that fail
    to extract are printed and skipped; a failure to open or read the file
    is printed and yields an empty list.

    Raises:
        FileNotFoundError: ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    documents = []
    try:
        with pdfplumber.open(file_path) as pdf:
            for page_num, page in enumerate(pdf.pages, start=1):
                try:
                    text = normalize_text(page.extract_text() or "")
                    sections = [text.strip()]

                    # Tables are appended after the page text, in order.
                    for t_idx, table in enumerate(page.extract_tables() or [], start=1):
                        rows = ["\t".join(cell or "" for cell in row) for row in table]
                        table_text = normalize_text("\n".join(rows))
                        sections.append(
                            f"\n\n=== Table {t_idx} (Page {page_num}) ===\n{table_text}"
                        )

                    documents.append(
                        Document(
                            page_content="".join(sections),
                            metadata={
                                "source": os.path.basename(file_path),
                                "page_number": page_num,
                            },
                        )
                    )
                except Exception as e:
                    print(f"Error extracting page {page_num}: {e}")
                    continue  # Skip corrupted pages, process others

    except Exception as e:
        print(f"Failed to open or read PDF file: {file_path}")
        print(f"Error: {e}")
        return []  # Return empty list instead of crashing

    return documents
49
+
50
+
51
+
52
+
53
+
54
def load_pdf_with_pages(file_path: str):
    """Return the text of every page of a PDF, tagged with its page number.

    The document is opened with PyMuPDF and closed again before returning,
    including when text extraction raises.

    Returns:
        A list of ``{"page": n, "text": str}`` dicts with 1-based ``page``.
    """
    import fitz

    # The context manager releases the file handle on every exit path.
    with fitz.open(file_path) as doc:
        return [
            {"page": number, "text": page.get_text()}
            for number, page in enumerate(doc, start=1)
        ]
66
+
ingestion/loaders/txt_loader.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+ from langchain_core.documents import Document
4
+ from ingestion.loaders.normalization import normalize_text
5
+
6
def load_txt(file_path: str) -> List[Document]:
    """Read a plain-text file into at most one ``Document``.

    The file is decoded as UTF-8, falling back to Latin-1. The content is
    normalized and, if anything remains, wrapped in a single document with
    ``type="text"``. A missing or unreadable file is printed and yields an
    empty list; a normalization error is printed and also yields an empty
    list.
    """
    if not os.path.exists(file_path):
        print(f"File not found: {file_path}")
        return []

    try:
        with open(file_path, "r", encoding="utf-8") as handle:
            raw = handle.read()
    except UnicodeDecodeError:
        try:
            with open(file_path, "r", encoding="latin-1") as handle:
                raw = handle.read()
        except Exception as e:
            print(f"Failed to read text file ({file_path}): {e}")
            return []
    except Exception as e:
        print(f"Could not open file ({file_path}): {e}")
        return []

    result = []
    try:
        body = normalize_text(raw)
        if body:
            result.append(
                Document(page_content=body, metadata={"source": file_path, "type": "text"})
            )
    except Exception as e:
        print(f"Error processing text file ({file_path}): {e}")

    return result
ingestion/pdf_outline.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fitz
2
+ from config import get_settings
3
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
4
+
5
def extract_pdf_outline(pdf_path: str):
    """Return the PDF's bookmark outline as a tree, plus its page count.

    The table of contents is read with ``get_toc(simple=False)`` and nested
    by level: each entry becomes a child of the closest preceding entry with
    a smaller level. The document is closed even if reading it fails.

    Returns:
        ``(outline, total_pages)`` where ``outline`` is a list of root nodes
        of the form ``{"level", "title", "page", "children"}``.
    """
    # Read everything needed while the file is open; the tree is built from
    # plain Python data afterwards.
    with fitz.open(pdf_path) as doc:
        toc = doc.get_toc(simple=False)
        total_pages = doc.page_count

    outline = []
    stack = []  # chain of ancestors of the entry being placed
    for level, title, page, *_ in toc:
        # Pop siblings and deeper entries until the top is a real ancestor.
        while stack and stack[-1]["level"] >= level:
            stack.pop()
        node = {"level": level, "title": title, "page": page, "children": []}
        if stack:
            stack[-1]["children"].append(node)
        else:
            outline.append(node)
        stack.append(node)

    return outline, total_pages
24
+
25
def build_page_bookmark_map(outline_tree, total_pages: int):
    """Map every page number to the bookmark path in effect on that page.

    A bookmark path is the list of titles from a root outline node down to
    the node itself. Pages ``1..total_pages`` each get the path of the most
    recent bookmarked page at or before them; pages before the first
    bookmark get an empty list. When several nodes point at the same page,
    the one visited last in depth-first order wins.
    """
    bookmarked = {}

    # Explicit-stack pre-order walk; children are pushed reversed so they
    # are visited in their original order.
    pending = [(root, []) for root in reversed(list(outline_tree))]
    while pending:
        node, ancestors = pending.pop()
        trail = ancestors + [node["title"]]
        bookmarked[node["page"]] = trail
        pending.extend((child, trail) for child in reversed(list(node["children"])))

    page_map = {}
    active_trail = []
    for page_num in range(1, total_pages + 1):
        # Carry the previous path forward when this page has no bookmark.
        active_trail = bookmarked.get(page_num, active_trail)
        page_map[page_num] = active_trail

    return page_map
46
+
47
def recursive_chunk_with_pages(pages):
    """Chunk each page's text separately, keeping the page number per chunk.

    ``pages`` is an iterable of dicts with ``"text"`` and ``"page"`` keys.
    Chunk size and overlap come from the ``CHUNK_SIZE`` and
    ``CHUNK_OVERLAP`` settings.

    Returns:
        A list of ``{"text": chunk, "page": page_number}`` dicts in page
        order.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=get_settings().CHUNK_SIZE,
        chunk_overlap=get_settings().CHUNK_OVERLAP,
    )
    return [
        {"text": piece, "page": entry["page"]}
        for entry in pages
        for piece in text_splitter.split_text(entry["text"])
    ]
main.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from routes.base import base_router
4
+ from routes.assisstant_rag import assisstant_router
5
+ from routes.exam_router import exam_router
6
+ from routes.exam_grading_router import grading_router
7
+
8
# Application entry point: create the FastAPI app, enable CORS, and mount
# every router defined in the routes package.
app = FastAPI()
# NOTE(review): wildcard origins are combined with allow_credentials=True.
# FastAPI's CORS documentation asks for explicit origins when credentials
# are allowed — confirm the intended list of front-end origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(base_router)
app.include_router(assisstant_router)
app.include_router(exam_router)
app.include_router(grading_router)
requirements.txt ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.120.0
2
+ uvicorn==0.38.0
3
+ python-dotenv==1.2.1
4
+ pdfplumber==0.11.7
5
+ python-docx==1.2.0
6
+ pandas==2.3.3
7
+ langchain==1.0.2
8
+ unstructured==0.18.15
9
+ PyMuPDF==1.26.5
10
+ docx2txt==0.9
11
+ Markdown==3.9
12
+ python-multipart==0.0.20
13
+ cohere==5.5.8
14
+ openai==1.35.13
15
+ qdrant-client==1.16.1
16
+ httpx==0.28.1
17
+ redis==7.2.0
18
+ celery==5.6.2
19
+ json_repair==0.58.5
routes/__init__.py ADDED
File without changes
routes/assisstant_rag.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter , UploadFile, File
2
+ from routes.schemas.Requests_Models import ChatRequest
3
+ from generation.AssistantRagGenerator import AssistantRagGen
4
+ from indexing.indexingController import IndexingController
5
+ from uuid import uuid4
6
+ from worker.tasks import process_file_task
7
+ from celery.result import AsyncResult
8
+ from celery_app import celery_app
9
+
10
+ assisstant_router = APIRouter(tags=["assistant_rag"])
11
+
12
@assisstant_router.get("/jobs/{job_id}")
def get_job_status(job_id: str):
    """Report the state of a Celery job.

    The state is read once so the response is built from a single
    consistent snapshot instead of several separate backend lookups.

    Returns:
        A dict with ``job_id`` and ``state``, plus ``message`` for
        ``PENDING``/``STARTED``, ``result`` for ``SUCCESS``, or ``error``
        (the stringified result) for ``FAILURE``.
    """
    result = AsyncResult(job_id, app=celery_app)
    state = result.state

    response = {"job_id": job_id, "state": state}
    if state == "PENDING":
        response["message"] = "Job is waiting in queue"
    elif state == "STARTED":
        response["message"] = "Job is currently processing"
    elif state == "SUCCESS":
        response["result"] = result.result
    elif state == "FAILURE":
        response["error"] = str(result.result)
    return response
43
+
44
@assisstant_router.post("/process-file")
async def process_file_endpoint(course: str , username: str , file: UploadFile = File(...)):
    """Save an uploaded file to a temp path and queue it for processing.

    The client-supplied filename is reduced to its final path component
    before it is used in the temp path, so directory separators in the name
    cannot redirect or break the write. The original filename is still what
    gets passed to the task and echoed back.

    Returns:
        A dict with the Celery ``job_id``, the uploaded ``filename`` and
        ``status`` set to ``"queued"``.
    """
    import os

    # Treat both separator styles as path boundaries; fall back to a fixed
    # name when the client sent no usable filename.
    raw_name = (file.filename or "").replace("\\", "/")
    safe_name = os.path.basename(raw_name) or "upload"

    job_id = uuid4().hex
    temp_path = f"./temp_{job_id}_{safe_name}"
    with open(temp_path, "wb") as f:
        f.write(await file.read())
    task = process_file_task.delay(temp_path, file.filename, username, course)
    return {
        "job_id": task.id,
        "filename": file.filename,
        "status": "queued",
    }
56
+
57
@assisstant_router.post("/chat/complete")
async def chat_complete_endpoint(request: ChatRequest):
    """Answer a chat prompt, optionally grounded in retrieved content.

    The question is routed by ``robust_router``; a ``source_file`` or
    ``bookmark`` on the request forces the ``pdf_query`` route (a bookmark
    without a source file is discarded). Depending on the route, the vector
    store is queried with role-based filters, a prompt is built, and the
    generator is asked for the answer.

    ``request.user_info`` is optional: when it is missing, course lists are
    treated as empty and the profile passed to the prompt is an empty dict.

    Returns:
        A dict with ``session_id``, ``route``, ``query``, ``history``,
        ``results`` and ``LLM_answer``.
    """
    indexing_controller = IndexingController()
    rag_gen = AssistantRagGen()
    user_query = request.prompt if request.prompt else "no question provided"
    route = rag_gen.robust_router({"question": user_query})

    results = []
    context_text = ""
    filters = []

    # Either way this is a document query once a file or bookmark is given.
    if request.source_file or request.bookmark:
        if request.bookmark and not request.source_file:
            request.bookmark = None
        route = "pdf_query"

    if route == "user_info":
        # Students never get to see instructor-owned files in their profile.
        if request.role == "student" and request.user_info is not None:
            request.user_info = request.user_info.model_copy(
                update={"instructor_owned_files": None}
            )

    elif route == "site_query":
        filters = [
            {"field": "course", "op": "eq", "value": "Instructions", "clause": "must"},
            {"field": "username", "op": "eq", "value": "ADMIN", "clause": "must"}
        ]
        embedding = indexing_controller.embedder.embed_text(user_query)
        results = indexing_controller.vector_store.query_qdrant(
            filters=filters,
            embedding=embedding,
            top_k=request.top_k
        )

    elif route == "pdf_query":
        # A missing profile or course list restricts the search to nothing
        # rather than crashing.
        courses = (request.user_info.courses if request.user_info is not None else None) or []

        if request.role == "student":
            enrolled = courses
            print(f"[DEBUG] Student {request.username} is enrolled in courses: {enrolled}")
            filters.append({"field": "course", "op": "in", "value": enrolled, "clause": "must"})

        elif request.role == "instructor":
            owned = courses
            print(f"[DEBUG] Instructor {request.username} owns courses/files: {owned}")
            filters.append({"field": "course", "op": "in", "value": owned, "clause": "must"})

        if request.source_file:
            filters.append({"field": "source", "op": "eq", "value": request.source_file, "clause": "must"})

        if request.bookmark:
            filters.append({"field": "bookmark_path", "op": "text", "value": request.bookmark, "clause": "must"})

        embedding = indexing_controller.embedder.embed_text(user_query)
        results = indexing_controller.vector_store.query_qdrant(
            filters=filters,
            embedding=embedding,
            top_k=request.top_k
        )

    if results:
        context_text = "\n\n".join([r["content"] for r in results if r.get("content")])

    history_str = "\n".join(
        f"Human: {turn.Human_msg}\nAssistant: {turn.LLM_response}"
        for turn in request.history
    ) if request.history else "None"

    # Computed after the student redaction above so the prompt sees the
    # filtered profile.
    profile = request.user_info.model_dump() if request.user_info is not None else {}
    user_info_text = str(profile)

    if route == "user_info":
        final_prompt = rag_gen.build_user_info_prompt(
            question=user_query,
            conversation_history=history_str,
            User_Info=user_info_text,
        )
    elif route == "site_query":
        final_prompt = rag_gen.build_site_query_prompt(
            question=user_query,
            context=context_text,
            conversation_history=history_str
        )
    else:
        final_prompt = rag_gen.build_unified_prompt(
            context=context_text,
            question=user_query,
            conversation_history=history_str,
            User_Info=user_info_text,
        )

    llm_response = rag_gen.generator.generate_text(prompt=final_prompt)

    return {
        "session_id": request.session_id,  # Return as is
        "route": route,
        "query": user_query,
        "history": request.history,  # Return as is
        "results": results,
        "LLM_answer": llm_response,
    }
routes/base.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter , Depends
2
+ from config import get_settings
3
+ from indexing.indexingController import IndexingController
4
+
5
+ base_router = APIRouter(tags=["base"])
6
+
7
+
8
@base_router.get("/health")
async def health_check(settings = Depends(get_settings)):
    """Liveness probe.

    Only the application name is exposed; the settings object itself is not
    returned, so configuration values are not leaked through this endpoint.

    Returns:
        ``{"status": "ok", "app_name": <name or None>}``.
    """
    # NOTE(review): assumes the settings object names the app ``APP_NAME``;
    # adjust the attribute if config.py uses a different field.
    return {"status": "ok", "app_name": getattr(settings, "APP_NAME", None)}
11
+
12
+ # @base_router.post("/all_docs")
13
+ # async def get_all_docs():
14
+ # indexing_controller = IndexingController()
15
+ # all_docs = indexing_controller.vector_store.get_all_documents()
16
+ # return {
17
+ # "total_docs": len(all_docs),
18
+ # "documents": all_docs
19
+ # }
20
+
21
@base_router.get("/all_files")
async def get__files():
    """Return every file known to the vector store, with a count."""
    files = IndexingController().vector_store.get_all_files()
    return {"total_files": len(files), "files": files}
28
+
29
+
30
@base_router.get("/remove_file")
async def remove_file(filename: str,username: str ,course: str):
    """Remove the stored points of one file and report the outcome.

    The response says ``"success"`` when the vector store reports a truthy
    result and ``"failure"`` otherwise.
    """
    removed = IndexingController().vector_store.remove_points_by_file(filename, username, course)
    if removed:
        return {"status": "success", "message": f"File '{filename}' removed."}
    return {"status": "failure", "message": f"File '{filename}' not found."}
38
+
39
@base_router.get("/user/docs")
async def get_user_docs(username: str):
    """Return the files and bookmarks stored for ``username``, with a count."""
    documents = IndexingController().vector_store.all_user_files_bookmarks(username)
    return {"total_docs": len(documents), "documents": documents}
routes/exam_grading_router.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, HTTPException
2
+ from pydantic import BaseModel
3
+ import re
4
+
5
+ from generation.ExamAnswer import ExamGradingService, grade_exam_task
6
+ from generation.answer_models import ExamSubmission, ExamResult
7
+ from routes.schemas.Exam_Models import ExamResponse
8
+
9
+ grading_router = APIRouter(prefix="/exam/grading", tags=["exam_grading"])
10
+
11
class GradingResponse(BaseModel):
    """Response body of the ``/submit`` endpoint for a queued grading job."""
    job_id: str  # id of the queued grading task
    exam_id: str
    student_id: str
    status: str  # set to "queued" by submit_exam
16
+
17
class GradingRequest(BaseModel):
    """Request body for grading: a student's submission plus the exam itself."""
    submission: ExamSubmission
    exam: ExamResponse  # source of the expected answers
20
+
21
+
22
def normalize_text(text: str) -> str:
    """Return a lookup key for question text.

    Punctuation (anything that is neither a word character nor whitespace)
    is removed, whitespace runs are collapsed to one space, and the result
    is trimmed and lower-cased. Falsy input yields an empty string.
    """
    if not text:
        return ""
    without_punctuation = re.sub(r'[^\w\s]', '', text)
    single_spaced = re.sub(r'\s+', ' ', without_punctuation)
    return single_spaced.strip().lower()
28
+
29
def _index_exam_questions(exam):
    """Map each exam question's normalized text to the question object."""
    return {normalize_text(q.question): q for q in exam.questions}


def _expected_answer(question, question_type):
    """Return the reference answer stored on ``question`` for ``question_type``.

    Returns ``None`` when the type is not recognized or the question does
    not carry the corresponding attribute.
    """
    if question_type in ("multiple_choice", "true_false"):
        return getattr(question, "correct_answer", None)
    if question_type == "short_answer":
        return getattr(question, "answer", None)
    if question_type == "code":
        return getattr(question, "solution", None)
    if question_type == "essay":
        # Guidelines win when present and non-empty; otherwise fall back to
        # the model answer.
        return getattr(question, "answer_guidelines", None) or getattr(question, "answer", None)
    return None


@grading_router.post("/submit", response_model=GradingResponse)
async def submit_exam(request: GradingRequest):
    """Attach the expected answers to a submission and queue it for grading.

    Each submitted answer is matched to an exam question by normalized
    question text. The expected answer (or ``None`` when nothing matches)
    is stored under ``metadata["correct_answer"]``; a missing or empty
    ``metadata`` value is replaced by a fresh dict first.

    Returns:
        A ``GradingResponse`` carrying the queued task id.
    """
    submission_dict = request.submission.model_dump()
    questions_by_text = _index_exam_questions(request.exam)

    for answer in submission_dict["answers"]:
        question_text = answer["question_text"]
        question_type = answer["question_type"]

        question = questions_by_text.get(normalize_text(question_text))
        correct_answer = None
        if question is not None:
            correct_answer = _expected_answer(question, question_type)

        # The dumped value can be present but None, so test the value and
        # not just the key.
        if not answer.get("metadata"):
            answer["metadata"] = {}
        answer["metadata"]["correct_answer"] = correct_answer

    task = grade_exam_task.delay(submission_dict)

    return GradingResponse(
        job_id=task.id,
        exam_id=request.submission.exam_id,
        student_id=request.submission.student_id,
        status="queued"
    )


@grading_router.post("/grade-sync", response_model=ExamResult)
async def grade_sync(request: GradingRequest):
    """Grade a submission immediately and return the result.

    Each submitted answer is matched to an exam question by normalized
    question text; when an expected answer is found it is stored under
    ``metadata["correct_answer"]`` before grading.

    Raises:
        HTTPException: status 400 carrying the message of any error raised
            while preparing or grading the submission.
    """
    try:
        service = ExamGradingService(use_ai_for_essays=True)
        questions_by_text = _index_exam_questions(request.exam)

        for ans in request.submission.answers:
            question_text = ans.question_text
            question_type = ans.question_type

            question = questions_by_text.get(normalize_text(question_text))
            correct_answer = None
            if question is not None:
                correct_answer = _expected_answer(question, question_type)

            # Unlike submit_exam, metadata is only touched when an expected
            # answer was actually found.
            if correct_answer is not None:
                if not ans.metadata:
                    ans.metadata = {}
                ans.metadata["correct_answer"] = correct_answer

        result = service.grade_submission(request.submission)
        return result
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
routes/exam_router.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter
2
+ from routes.schemas.Exam_Models import ExamGenerationRequest
3
+ from worker.tasks import generate_exam_task
4
+
5
+ exam_router = APIRouter(prefix="/exam", tags=["exam"])
6
+
7
+
8
@exam_router.post("/create")
async def process_file_endpoint(request: ExamGenerationRequest):
    """Queue an exam-generation task for ``request`` and return its job id."""
    payload = request.model_dump()
    queued = generate_exam_task.delay(payload)
    response = {"job_id": queued.id}
    response["exam_id"] = request.exam_id
    response["status"] = "queued"
    return response
routes/schemas/Exam_Models.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, field_validator, model_validator
2
+ from typing import List, Optional, Dict
3
+ from enum import Enum
4
+ from typing import Union
5
+ from typing import Literal
6
+ from pydantic import Field
7
+ from typing import Annotated
8
+
9
+
10
class QuestionType(str, Enum):
    """String-valued identifiers for the supported kinds of exam question."""
    MCQ = "mcq"
    TRUE_FALSE = "true_false"
    SHORT_ANSWER = "short_answer"
    ESSAY = "essay"
    CODE = "code"
16
+
17
class DifficultyLevel(str, Enum):
    """String-valued exam difficulty levels."""
    EASY = "easy"
    MEDIUM = "medium"
    HARD = "hard"
21
+
22
class Reference(BaseModel):
    """A source file, optionally narrowed to a list of bookmarks."""
    filename: str
    bookmarks: Optional[List[str]] = None  # None means no bookmark restriction was given
25
+
26
class ExamGenerationRequest(BaseModel):
    """Parameters for generating an exam.

    Validation rejects an empty ``topics`` list, any non-positive count in
    ``question_types_distribution``, and a distribution whose counts do not
    add up to ``total_questions``. Unknown fields are ignored.
    """
    username: str
    course: str
    exam_id: str
    total_questions: int
    topics: List[str]
    references: Optional[List[Reference]] = None
    difficulty: Optional[DifficultyLevel] = DifficultyLevel.MEDIUM
    include_answer_key: Optional[bool] = True
    question_types_distribution: Dict[QuestionType, int]
    model_config = {"extra": "ignore"}

    @field_validator("topics")
    @classmethod
    def validate_topics(cls, v):
        """Require at least one topic."""
        if v:
            return v
        raise ValueError("Topics cannot be empty")

    @field_validator("question_types_distribution")
    @classmethod
    def validate_positive(cls, v):
        """Require every per-type count to be strictly positive."""
        for count in v.values():
            if count <= 0:
                raise ValueError("All distribution counts must be > 0")
        return v

    @model_validator(mode="after")
    def validate_sum(self):
        """Require the per-type counts to add up to ``total_questions``."""
        requested = sum(self.question_types_distribution.values())
        if requested != self.total_questions:
            raise ValueError("Distribution must equal total_questions")
        return self
57
+
58
class QuestionBase(BaseModel):
    """Fields shared by every question model; unknown fields are ignored."""
    type: QuestionType
    question: str
    model_config = {"extra": "ignore"}
62
+
63
class MCQQuestion(QuestionBase):
    """Multiple-choice question.

    Validation requires at least two options and a ``correct_answer`` that
    is one of them.
    """
    type: Literal[QuestionType.MCQ]
    options: List[str]
    correct_answer: str
    explanation: str

    @model_validator(mode="after")
    def validate_mcq(self):
        """Check the option count, then that the answer is a listed option."""
        choices = self.options
        if len(choices) < 2:
            raise ValueError("MCQ must contain at least 2 options")
        if self.correct_answer in choices:
            return self
        raise ValueError("correct_answer must exist in options")
76
+
77
class TrueFalseQuestion(QuestionBase):
    """True/false question with a boolean answer and an explanation."""
    type: Literal[QuestionType.TRUE_FALSE]
    correct_answer: bool
    explanation: str
81
+
82
class ShortAnswerQuestion(QuestionBase):
    """Short-answer question with its reference answer and an explanation."""
    type: Literal[QuestionType.SHORT_ANSWER]
    answer: str
    explanation: str
86
+
87
class EssayQuestion(QuestionBase):
    """Essay question with a reference answer and answer guidelines."""
    type: Literal[QuestionType.ESSAY]
    answer: str
    answer_guidelines: str
91
+
92
class CodeQuestion(QuestionBase):
    """Coding question with optional starter code and a full solution."""
    type: Literal[QuestionType.CODE]

    starter_code: Optional[str] = Field(
        default=None,
        description="Starter code shown to the student"
    )

    language: str = "c"

    solution: str = Field(
        description="Correct full solution code"
    )

    explanation: str = Field(
        description="Explanation of how the solution works"
    )

    @field_validator("starter_code", "solution")
    @classmethod
    def normalize_code(cls, v):
        """Convert escaped newlines to real newlines when the code has none.

        Code that already contains real line breaks is returned untouched:
        there, a literal ``\\n`` is most likely an escape sequence inside a
        string literal (for example ``printf("hi\\n")``), and rewriting it
        would break the program. Empty or ``None`` values pass through
        unchanged.
        """
        if v and "\n" not in v:
            return v.replace("\\n", "\n")
        return v
117
+
118
# Union of all concrete question models; the "type" field is the
# discriminator used to pick the model for each question.
QuestionUnion = Annotated[
    Union[
        MCQQuestion,
        TrueFalseQuestion,
        ShortAnswerQuestion,
        EssayQuestion,
        CodeQuestion,
    ],
    Field(discriminator="type"),
]
128
+
129
class ExamResponse(BaseModel):
    """A generated exam.

    Validation requires the number of questions to equal
    ``total_questions`` and the per-type question counts to match
    ``expected_distribution`` exactly. Unknown fields are ignored.
    """
    exam_id: str
    difficulty: DifficultyLevel
    total_questions: int
    questions: List[QuestionUnion]
    expected_distribution: Dict[QuestionType, int]
    model_config = {"extra": "ignore"}

    @model_validator(mode="after")
    def validate_question_count(self):
        """Require exactly ``total_questions`` questions."""
        if len(self.questions) == self.total_questions:
            return self
        raise ValueError(
            f"Expected {self.total_questions} questions, "
            f"but got {len(self.questions)}"
        )

    @model_validator(mode="after")
    def validate_distribution(self):
        """Require the actual per-type counts to equal the expected ones."""
        tally: Dict[QuestionType, int] = {}
        for item in self.questions:
            tally[item.type] = tally.get(item.type, 0) + 1

        # Key views compare like sets: the same question types must appear
        # on both sides.
        if tally.keys() != self.expected_distribution.keys():
            raise ValueError("Unexpected question types in exam")

        for q_type, expected_count in self.expected_distribution.items():
            actual = tally.get(q_type, 0)
            if actual == expected_count:
                continue
            raise ValueError(
                f"Distribution mismatch for {q_type.value}: "
                f"expected {expected_count}, got {actual}"
            )

        return self
166
+
167
class AnswerItem(BaseModel):
    """One answer-key entry: a question index and its answer text."""
    question_index: int
    answer: str
170
+
171
class AnswerKey(BaseModel):
    """Answer key for one exam; unknown fields are ignored."""
    exam_id: str
    answers: List[AnswerItem]
    model_config = {"extra": "ignore"}
175
+
176
class EvaluationResult(BaseModel):
    """An overall integer score with textual feedback; unknown fields are ignored."""
    overall_score: int
    feedback: str
    model_config = {"extra": "ignore"}
180
+
routes/schemas/Requests_Models.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+ from typing import Optional, List
3
+
4
# One exchange of a conversation: the human message and the LLM reply to it.
class ConversationTurn(BaseModel):
    Human_msg: str
    LLM_response: str
7
+
8
# Optional user context attached to a chat request. Every field defaults to
# None, so a client may send any subset of them.
class UserInfoRequest(BaseModel):
    courses: Optional[List[str]] = None
    deadlines: Optional[List[str]] = None
    grades: Optional[List[str]] = None
    instructor_owned_files: Optional[List[str]] = None
    # free-form extra information
    more_info: Optional[str] = None
14
+
15
# Request body of a chat call. username, session_id and role are required;
# everything else is optional.
class ChatRequest(BaseModel):
    prompt: Optional[str] = None
    username: str
    session_id: str
    role: str
    # presumably the number of retrieved chunks to use -- TODO confirm with the route
    top_k: int = 5
    source_file: Optional[str] = None
    bookmark: Optional[str] = None
    # earlier turns of the conversation, if any
    history: Optional[List[ConversationTurn]] = None
    user_info: Optional[UserInfoRequest]= None
routes/schemas/__init__.py ADDED
File without changes
stores/llm/LLMEnums.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
class LLMEnums(Enum):
    """Names of the supported LLM backends.

    The string values are what ``LLMProviderFactory.create`` compares its
    ``provider`` argument against.
    """

    OPENAI = "OPENAI"
    COHERE = "COHERE"
    OLLAMA = "OLLAMA"
    MISTRAL = "MISTRAL"
    GROQ = "GROQ"
    OPENROUTER = "OPENROUTER"
    HUGGINGFACE = "HUGGINGFACE"
    DEEPSEEK = "DEEPSEEK"
    GEMINI = "GEMINI"
13
+
14
class OpenAIEnums(Enum):
    """Chat message role names in the lowercase OpenAI style."""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
18
+
19
class CoHereEnums(Enum):
    """Cohere-specific string constants: chat roles and embedding input types."""

    SYSTEM = "SYSTEM"
    USER = "USER"
    ASSISTANT = "CHATBOT"
    # Embedding input types; the same strings are accepted as ``document_type``
    # by ``CohereProvider.embed_text``.
    DOCUMENT = "search_document"
    QUERY = "search_query"
25
+
26
class DocumentTypeEnum(Enum):
    """Provider-independent label for a text being embedded: document or query."""

    DOCUMENT = "document"
    QUERY = "query"
stores/llm/LLMInterface.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
class LLMInterface(ABC):
    """Abstract contract every LLM provider class has to implement."""

    @abstractmethod
    def set_generation_model(self, model_id: str):
        """Select the model used for text generation."""

    @abstractmethod
    def set_embedding_model(self, model_id: str, embedding_size: int):
        """Select the embedding model and record the size of its vectors."""

    @abstractmethod
    def generate_text(self, prompt: str, chat_history: list = None, max_output_tokens: int = None,
                      temperature: float = None):
        """Generate a completion for ``prompt``.

        ``chat_history`` defaults to ``None`` instead of a shared ``[]``: a
        mutable default is created once and would be shared by every call
        that relies on it. Implementations treat ``None`` as "no history".
        """

    @abstractmethod
    def embed_text_batch(self, texts: list[str], batch_size: int = 32):
        """Embed ``texts``, sending at most ``batch_size`` of them per request."""

    @abstractmethod
    def construct_prompt(self, prompt: str, role: str):
        """Wrap ``prompt`` and ``role`` into the message structure of the provider."""
stores/llm/LLMProviderFactory.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .LLMEnums import LLMEnums
2
+ from stores.llm.providers.OpenAIProvider import OpenAIProvider
3
+ from stores.llm.providers.OllamaProvider import OllamaProvider
4
+ from stores.llm.providers.CohereProvider import CohereProvider
5
+ from stores.llm.providers.MistralProvider import MistralProvider
6
+ from stores.llm.providers.GroqProvider import GroqProvider
7
+ from stores.llm.providers.OpenRouterProvider import OpenRouterProvider
8
+ from stores.llm.providers.HuggingFaceProvider import HuggingFaceProvider
9
+ from stores.llm.providers.DeepSeekProvider import DeepSeekProvider
10
+ from stores.llm.providers.GeminiProvider import GeminiProvider
11
+
12
+
13
class LLMProviderFactory:
    """Build LLM provider clients from the application settings."""

    def __init__(self, config):
        # ``config`` is read through attribute access (``config.OPENAI_API_KEY``
        # and so on), so it is a settings object, not a plain dict.
        self.config = config

    def create(self, provider: str):
        """Return a provider instance for ``provider``.

        Args:
            provider: Backend name, i.e. one of the ``LLMEnums`` values. An
                ``LLMEnums`` member is accepted as well.

        Returns:
            The constructed provider, or ``None`` when the name is unknown.
        """
        if isinstance(provider, LLMEnums):
            provider = provider.value

        # backend name -> (provider class, {constructor keyword: settings attribute})
        registry = {
            LLMEnums.OPENAI.value: (
                OpenAIProvider,
                {"api_key": "OPENAI_API_KEY", "api_url": "OPENAI_API_URL"},
            ),
            LLMEnums.OLLAMA.value: (
                OllamaProvider,
                {"url": "OLLAMA_URL", "api_key": "OLLAMA_API_KEY"},
            ),
            LLMEnums.COHERE.value: (CohereProvider, {"api_key": "COHERE_API_KEY"}),
            LLMEnums.MISTRAL.value: (MistralProvider, {"api_key": "MISTRAL_API_KEY"}),
            LLMEnums.GROQ.value: (GroqProvider, {"api_key": "GROQ_API_KEY"}),
            LLMEnums.OPENROUTER.value: (OpenRouterProvider, {"api_key": "OPENROUTER_API_KEY"}),
            LLMEnums.HUGGINGFACE.value: (HuggingFaceProvider, {"api_key": "HF_API_KEY"}),
            LLMEnums.DEEPSEEK.value: (DeepSeekProvider, {"api_key": "DEEPSEEK_API_KEY"}),
            LLMEnums.GEMINI.value: (GeminiProvider, {"api_key": "GEMINI_API_KEY"}),
        }

        entry = registry.get(provider)
        if entry is None:
            return None

        provider_cls, setting_names = entry
        settings = self.config
        # Settings are read only for the selected backend, so a missing key of
        # an unused backend never matters.
        specific_kwargs = {
            keyword: getattr(settings, attribute)
            for keyword, attribute in setting_names.items()
        }
        return provider_cls(
            **specific_kwargs,
            default_input_max_characters=settings.INPUT_DAFAULT_MAX_CHARACTERS,
            default_generation_max_output_tokens=settings.GENERATION_DAFAULT_MAX_TOKENS,
            default_generation_temperature=settings.GENERATION_DAFAULT_TEMPERATURE,
        )
stores/llm/__init__.py ADDED
File without changes
stores/llm/providers/CohereProvider.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from stores.llm.LLMInterface import LLMInterface
2
+ import logging
3
+ import requests
4
+ import re
5
+ import os
6
+ import time
7
+ import math
8
class CohereProvider(LLMInterface):
    """LLM provider that calls the Cohere REST API through ``requests``.

    Offers chat generation, single and batched embeddings and a chat-based
    web search. ``generate_text``, ``embed_text`` and ``web_search`` catch
    their own errors; ``embed_text_batch`` lets ``requests`` exceptions
    propagate.
    """

    # Request timeouts in seconds.
    _CHAT_TIMEOUT_SECONDS = 6000
    _EMBED_TIMEOUT_SECONDS = 200

    # Values of ``document_type`` that are forwarded as Cohere ``input_type``.
    _EMBED_INPUT_TYPES = ("search_document", "search_query", "classification", "clustering")

    # Batched embedding pacing: a fixed pause before every request keeps the
    # client at roughly ten requests per minute.
    _EMBED_BATCH_PAUSE_SECONDS = 6
    _EMBED_MAX_RETRIES = 5
    _EMBED_BACKOFF_BASE_SECONDS = 10  # doubled after every rate-limited attempt

    def __init__(self, url: str = None, model: str = None,
                 default_input_max_characters: int = 1000,
                 default_generation_max_output_tokens: int = 1000,
                 default_generation_temperature: float = 0.1, api_key: str = None):
        """Store connection settings and generation defaults.

        ``url`` falls back to the public v2 endpoint and ``api_key`` to the
        ``COHERE_API_KEY`` environment variable.
        """
        self.url = url or "https://api.cohere.com/v2"
        self.api_key = api_key or os.getenv("COHERE_API_KEY")
        self.model = model
        self.generation_model_id = None

        self.embedding_model = None
        self.embedding_model_id = None
        self.embedding_size = None

        self.default_input_max_characters = default_input_max_characters
        self.default_generation_max_output_tokens = default_generation_max_output_tokens
        self.default_generation_temperature = default_generation_temperature
        self.logger = logging.getLogger(__name__)

    def _endpoint(self, path: str) -> str:
        """Join the base URL and ``path`` without doubling the slash."""
        return self.url.rstrip("/") + path

    def _headers(self) -> dict:
        """Return the JSON + bearer-token headers shared by all requests."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def set_generation_model(self, model_id: str):
        """Set the generation model; an empty ``model_id`` is ignored."""
        if model_id:
            self.model = model_id

    def set_embedding_model(self, model_id: str, embedding_size: int):
        """Set the embedding model and its size; an empty ``model_id`` is ignored."""
        if model_id:
            self.embedding_model = model_id
            self.embedding_size = embedding_size
            self.embedding_model_id = model_id

    def process_text(self, text: str):
        """Return ``text`` as a stripped string, or ``""`` for a falsy value."""
        if not text:
            return ""
        return str(text).strip()

    def generate_text(self, prompt: str, chat_history: list = None,
                      max_output_tokens: int = None, temperature: float = None):
        """Send ``chat_history`` plus ``prompt`` to the chat endpoint.

        Returns:
            A dict with ``model``, ``response``, ``tokens_generated``,
            ``total_duration_ms`` (always ``None``) and
            ``prompt_eval_tokens``; ``None`` on any failure or when the reply
            text is empty.
        """
        try:
            chat_history = chat_history or []
            messages = [
                {"role": entry.get("role", "user"), "content": entry.get("content", "")}
                for entry in chat_history
            ]
            messages.append({"role": "user", "content": self.process_text(prompt)})

            # ``is None`` instead of ``or``: an explicit temperature of 0.0 is
            # a valid request and must not be replaced by the default.
            if temperature is None:
                temperature = self.default_generation_temperature

            payload = {
                "model": self.model,
                "messages": messages,
                "max_tokens": int(max_output_tokens or self.default_generation_max_output_tokens),
                "temperature": float(temperature),
            }

            resp = requests.post(
                self._endpoint("/chat"),
                json=payload,
                headers=self._headers(),
                timeout=self._CHAT_TIMEOUT_SECONDS,
            )
            if resp.status_code != 200:
                self.logger.error("Cohere generate failed: %s %s", resp.status_code, resp.text)
                return None

            data = resp.json()

            # The reply text sits at message.content[0].text.
            try:
                generated_text = data["message"]["content"][0]["text"].strip()
            except (KeyError, IndexError, TypeError):
                self.logger.error("Unexpected Cohere response structure: %s", data)
                return None

            if not generated_text:
                return None

            token_usage = data.get("usage", {}).get("tokens", {})
            return {
                "model": data.get("model"),
                "response": generated_text,
                "tokens_generated": token_usage.get("output_tokens"),
                "total_duration_ms": None,  # not read from the response
                "prompt_eval_tokens": token_usage.get("input_tokens"),
            }

        except Exception as e:
            self.logger.exception("Error in CohereProvider.generate_text: %s", e)
            return None

    def embed_text(self, text: str, document_type: str = None):
        """Return the embedding vector of one text.

        Returns:
            The vector, ``[]`` when ``text`` is empty after stripping, or
            ``None`` when no model is set or the request fails.
        """
        try:
            if not self.embedding_model:
                self.logger.error("Embedding model is not set before calling embed_text()")
                return None

            clean_text = self.process_text(text)
            self.logger.debug("Cleaned text: '%s...'", clean_text[:20])
            if not clean_text:
                return []

            # Unknown or missing document types fall back to "search_document".
            if document_type in self._EMBED_INPUT_TYPES:
                input_type = document_type
            else:
                input_type = "search_document"

            payload = {
                "model": self.embedding_model,
                "texts": [clean_text],
                "input_type": input_type,
                "embedding_types": ["float"],
            }

            resp = requests.post(
                self._endpoint("/embed"),
                json=payload,
                headers=self._headers(),
                timeout=self._EMBED_TIMEOUT_SECONDS,
            )
            if resp.status_code != 200:
                self.logger.error("Cohere embedding failed: %s %s", resp.status_code, resp.text)
                return None

            data = resp.json()

            embedding = None
            try:
                embedding = data["embeddings"]["float"][0]  # v2 shape
            except (KeyError, IndexError, TypeError):
                try:
                    embedding = data["embeddings"][0]  # older list shape
                except (KeyError, IndexError, TypeError):
                    embedding = None

            if embedding is None:
                self.logger.warning("'embedding' key not found in response JSON")
                return None

            self.logger.debug("Embedding length: %d", len(embedding))
            return embedding

        except Exception as e:
            self.logger.exception("Error in CohereProvider.embed_text: %s", e)
            return None

    def construct_prompt(self, prompt: str, role: str):
        """Return a ``{"role", "content"}`` message with stripped content."""
        return {
            "role": role,
            "content": self.process_text(prompt)
        }

    @staticmethod
    def _print_progress(pct: float, done_batches: int, total_batches: int,
                        done_texts: int, total_texts: int, final: bool = False):
        """Redraw the single-line console progress bar (20 characters wide)."""
        filled = int(pct / 5)
        bar = "█" * filled + "░" * (20 - filled)
        line = (
            f"\r[EMBED] [{bar}] {pct:5.1f}% "
            f"batch {done_batches}/{total_batches} "
            f"({done_texts}/{total_texts} texts)"
        )
        if final:
            print(line + " ✓")
        else:
            print(line, end="", flush=True)

    def _retry_delay(self, resp, attempt: int) -> float:
        """Return how long to wait after a 429 reply on the given attempt.

        A numeric ``Retry-After`` header wins; otherwise the delay starts at
        the back-off base and doubles with every attempt.
        """
        fallback = float(self._EMBED_BACKOFF_BASE_SECONDS * 2 ** (attempt - 1))
        header_value = resp.headers.get("Retry-After")
        if header_value is None:
            return fallback
        try:
            return max(0.0, float(header_value))
        except (TypeError, ValueError):
            # Retry-After may also be an HTTP date, which float() cannot parse.
            return fallback

    def _post_embed_with_retry(self, url: str, payload: dict, headers: dict, batch_idx: int):
        """POST one embedding batch, retrying on HTTP 429.

        Returns the successful response, or ``None`` after a non-429 error or
        once all attempts were rate limited (both cases are logged).
        """
        for attempt in range(1, self._EMBED_MAX_RETRIES + 1):
            resp = requests.post(url, json=payload, headers=headers,
                                 timeout=self._EMBED_TIMEOUT_SECONDS)
            if resp.status_code == 200:
                return resp

            if resp.status_code != 429:
                self.logger.error(
                    "Cohere embedding failed (batch %d, attempt %d): %s %s",
                    batch_idx, attempt, resp.status_code, resp.text
                )
                return None

            if attempt == self._EMBED_MAX_RETRIES:
                break  # no point in waiting when no attempt follows

            wait_seconds = self._retry_delay(resp, attempt)
            print(
                f"\n[RATE LIMIT] batch {batch_idx} — "
                f"attempt {attempt}/{self._EMBED_MAX_RETRIES}, "
                f"waiting {wait_seconds:.1f}s …"
            )
            time.sleep(wait_seconds)

        self.logger.error(
            "Cohere embedding: max retries (%d) exceeded on batch %d",
            self._EMBED_MAX_RETRIES, batch_idx
        )
        return None

    def embed_text_batch(self, texts: list[str], batch_size: int = 96):
        """Embed ``texts`` in batches of ``batch_size`` as search documents.

        Texts that are empty after stripping are not sent (the drop is logged
        as a warning), so the result can be shorter than ``texts``. A console
        progress bar is printed while the batches run.

        Returns:
            The list of embedding vectors in request order, or ``None`` when
            no model is set, a request fails, or a reply holds no embeddings.
        """
        self.logger.info(f"Embedding {len(texts)} texts using batch_size={batch_size}")

        if not self.embedding_model:
            self.logger.error("Embedding model not set")
            return None

        all_embeddings = []
        total_texts = len(texts)
        total_batches = math.ceil(total_texts / batch_size)
        url = self._endpoint("/embed")
        headers = self._headers()

        for batch_idx, start in enumerate(range(0, total_texts, batch_size), start=1):
            batch = texts[start:start + batch_size]
            # Filter after cleaning: a whitespace-only text strips to "" and
            # must not be sent either.
            cleaned = (self.process_text(t) for t in batch)
            clean_batch = [t for t in cleaned if t]

            dropped = len(batch) - len(clean_batch)
            if dropped:
                self.logger.warning(
                    "Skipping %d empty text(s) in batch %d", dropped, batch_idx
                )

            self._print_progress(
                (batch_idx / total_batches) * 100,
                batch_idx, total_batches,
                min(start + batch_size, total_texts), total_texts,
            )

            if not clean_batch:
                # Nothing to embed; an empty request would only fail.
                continue

            time.sleep(self._EMBED_BATCH_PAUSE_SECONDS)

            payload = {
                "model": self.embedding_model,
                "texts": clean_batch,
                "input_type": "search_document",
                "embedding_types": ["float"],
            }

            resp = self._post_embed_with_retry(url, payload, headers, batch_idx)
            if resp is None:
                return None

            data = resp.json()
            try:
                embeddings = data["embeddings"]["float"]  # v2 shape
            except (KeyError, TypeError):
                embeddings = data.get("embeddings")  # older list shape

            if not embeddings:
                self.logger.error("No embeddings returned from Cohere (batch %d)", batch_idx)
                return None

            self.logger.debug(f"Received {len(embeddings)} embeddings for batch {batch_idx}")
            all_embeddings.extend(embeddings)

        # Final, completed bar followed by a newline.
        self._print_progress(100.0, total_batches, total_batches,
                             total_texts, total_texts, final=True)

        self.logger.info(f"Total embeddings created: {len(all_embeddings)}")
        return all_embeddings

    def clean_content(self, text: str) -> str:
        """Remove markdown links and bracketed spans, then collapse blank lines."""
        text = re.sub(r'\[.*?\]\(.*?\)', '', text)
        text = re.sub(r'\[[^\]]*\]', '', text)
        text = re.sub(r'\n+', '\n', text).strip()
        return text

    def web_search(self, query: str):
        """Ask the chat endpoint with a ``web_search`` tool entry.

        Returns:
            ``{"text": ..., "references": [...]}``. A non-200 reply yields a
            "no results" text and an exception yields an error text; both come
            with an empty reference list.
        """
        try:
            payload = {
                "model": self.model,
                "messages": [{"role": "user", "content": query}],
                "tools": [{"type": "web_search"}],
            }

            resp = requests.post(
                self._endpoint("/chat"),
                json=payload,
                headers=self._headers(),
                timeout=self._CHAT_TIMEOUT_SECONDS,
            )

            if resp.status_code != 200:
                return {
                    "text": "No relevant external results found.",
                    "references": []
                }

            data = resp.json()

            combined_text = []
            references = set()

            try:
                assistant_text = data["message"]["content"][0]["text"]
                combined_text.append(self.clean_content(assistant_text))
            except (KeyError, IndexError, TypeError):
                pass  # a reply without text still may carry citations

            # Collect source URLs from the citations block.
            for citation in data.get("message", {}).get("citations", []):
                for source in citation.get("sources", []):
                    # ``or ""`` guards against an explicit null url/id.
                    url_val = source.get("url") or source.get("id") or ""
                    if isinstance(url_val, str) and url_val.startswith("http"):
                        references.add(url_val)

            # Bare URLs inside the text count as references too.
            raw_text = "\n".join(combined_text)
            references.update(re.findall(r"https?://[^\s)]+", raw_text))

            return {
                "text": "\n\n".join(combined_text[:3]),
                "references": list(references)
            }

        except Exception as e:
            self.logger.error("Cohere web search failed: %s", e)
            return {
                "text": f"Cohere search error: {str(e)}",
                "references": []
            }
stores/llm/providers/DeepSeekProvider.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from stores.llm.LLMInterface import LLMInterface
2
+ import logging
3
+ import requests
4
+ import re
5
+ import os
6
+
7
+
8
class DeepSeekProvider(LLMInterface):
    """LLM provider for the DeepSeek chat-completions REST API.

    Only text generation is implemented; the embedding and web-search
    methods log a warning and return a "not supported" result.
    """

    def __init__(self, url: str = None, model: str = None,
                 default_input_max_characters: int = 1000,
                 default_generation_max_output_tokens: int = 1000,
                 default_generation_temperature: float = 0.1, api_key: str = None):
        """Store connection settings and generation defaults.

        ``url`` falls back to ``https://api.deepseek.com/v1`` and ``api_key``
        to the ``DEEPSEEK_API_KEY`` environment variable.
        """
        self.url = url or "https://api.deepseek.com/v1"
        self.api_key = api_key or os.getenv("DEEPSEEK_API_KEY")
        self.model = model
        self.generation_model_id = None

        self.embedding_model = None
        self.embedding_model_id = None
        self.embedding_size = None

        self.default_input_max_characters = default_input_max_characters
        self.default_generation_max_output_tokens = default_generation_max_output_tokens
        self.default_generation_temperature = default_generation_temperature
        self.logger = logging.getLogger(__name__)

    def set_generation_model(self, model_id: str):
        """Set the generation model; an empty ``model_id`` is ignored."""
        if model_id:
            self.model = model_id

    def set_embedding_model(self, model_id: str, embedding_size: int):
        """Record the embedding model and size; an empty ``model_id`` is ignored."""
        if model_id:
            self.embedding_model = model_id
            self.embedding_size = embedding_size
            self.embedding_model_id = model_id

    def process_text(self, text: str):
        """Return ``text`` as a stripped string, or ``""`` for a falsy value."""
        if not text:
            return ""
        return str(text).strip()

    def generate_text(self, prompt: str, chat_history: list = None,
                      max_output_tokens: int = None, temperature: float = None):
        """Send ``chat_history`` plus ``prompt`` to ``/chat/completions``.

        Returns:
            A dict with ``model``, ``response``, ``tokens_generated``,
            ``total_duration_ms`` (always ``None``) and
            ``prompt_eval_tokens``; ``None`` on any failure or when the reply
            text is empty.
        """
        try:
            chat_history = chat_history or []
            messages = [
                {"role": entry.get("role", "user"), "content": entry.get("content", "")}
                for entry in chat_history
            ]
            messages.append({"role": "user", "content": self.process_text(prompt)})

            # ``is None`` instead of ``or``: an explicit temperature of 0.0 is
            # a valid request and must not be replaced by the default.
            if temperature is None:
                temperature = self.default_generation_temperature

            payload = {
                "model": self.model,
                "messages": messages,
                "max_tokens": int(max_output_tokens or self.default_generation_max_output_tokens),
                "temperature": float(temperature),
            }

            url = self.url.rstrip("/") + "/chat/completions"
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }

            resp = requests.post(url, json=payload, headers=headers, timeout=6000)
            if resp.status_code != 200:
                self.logger.error("DeepSeek generate failed: %s %s", resp.status_code, resp.text)
                return None

            data = resp.json()

            try:
                generated_text = data["choices"][0]["message"]["content"].strip()
            except (KeyError, IndexError, TypeError):
                self.logger.error("Unexpected DeepSeek response structure: %s", data)
                return None

            if not generated_text:
                return None

            usage = data.get("usage", {})
            return {
                "model": data.get("model"),
                "response": generated_text,
                "tokens_generated": usage.get("completion_tokens"),
                "total_duration_ms": None,
                "prompt_eval_tokens": usage.get("prompt_tokens"),
            }

        except Exception as e:
            self.logger.exception("Error in DeepSeekProvider.generate_text: %s", e)
            return None

    def embed_text(self, text: str, document_type: str = None):
        """Embeddings are not implemented for this provider; always returns ``None``."""
        self.logger.warning("DeepSeekProvider does not support embeddings.")
        return None

    def construct_prompt(self, prompt: str, role: str):
        """Return a ``{"role", "content"}`` message with stripped content."""
        return {
            "role": role,
            "content": self.process_text(prompt)
        }

    def embed_text_batch(self, texts: list[str], batch_size: int = 32):
        """Embeddings are not implemented for this provider; always returns ``None``."""
        self.logger.warning("DeepSeekProvider does not support embeddings.")
        return None

    def clean_content(self, text: str) -> str:
        """Remove markdown links and bracketed spans, then collapse blank lines."""
        text = re.sub(r'\[.*?\]\(.*?\)', '', text)
        text = re.sub(r'\[[^\]]*\]', '', text)
        text = re.sub(r'\n+', '\n', text).strip()
        return text

    def web_search(self, query: str):
        """Web search is not implemented; returns a notice with no references."""
        self.logger.warning("DeepSeekProvider.web_search is not natively supported.")
        return {
            "text": "Web search is not natively supported by the DeepSeek API.",
            "references": []
        }
stores/llm/providers/GeminiProvider.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ from stores.llm.LLMInterface import LLMInterface
4
+ import logging
5
+ import requests
6
+ import re
7
+ import os
8
+
9
+
10
+ class GeminiProvider(LLMInterface):
11
def __init__(self, url: str = None, model: str = None,
             default_input_max_characters: int = 1000,
             default_generation_max_output_tokens: int = 1000,
             default_generation_temperature: float = 0.1, api_key: str = None):
    """Store connection settings and generation defaults.

    A falsy ``url`` is replaced by the public v1beta endpoint and a falsy
    ``api_key`` by the ``GEMINI_API_KEY`` environment variable.
    """
    self.logger = logging.getLogger(__name__)

    # Connection
    self.url = url if url else "https://generativelanguage.googleapis.com/v1beta"
    self.api_key = api_key if api_key else os.getenv("GEMINI_API_KEY")

    # Generation model
    self.model = model
    self.generation_model_id = None

    # Embedding model: nothing is selected until set_embedding_model() runs.
    self.embedding_model = self.embedding_model_id = self.embedding_size = None

    # Defaults used when a call does not override them
    self.default_generation_temperature = default_generation_temperature
    self.default_generation_max_output_tokens = default_generation_max_output_tokens
    self.default_input_max_characters = default_input_max_characters
28
+
29
def set_generation_model(self, model_id: str):
    """Switch the generation model; a falsy ``model_id`` leaves it unchanged."""
    if not model_id:
        return
    self.model = model_id
32
+
33
def set_embedding_model(self, model_id: str, embedding_size: int):
    """Select the embedding model and remember its vector size.

    A falsy ``model_id`` leaves all three embedding attributes untouched.
    """
    if not model_id:
        return
    self.embedding_model = self.embedding_model_id = model_id
    self.embedding_size = embedding_size
38
+
39
def process_text(self, text: str):
    """Return ``text`` converted to ``str`` and stripped; falsy input gives ``""``."""
    return str(text).strip() if text else ""
43
+
44
def _build_contents(self, prompt: str, chat_history: list) -> list:
    """Translate the history and the new prompt into Gemini ``contents``.

    Every history entry becomes ``{"role", "parts": [{"text"}]}``; the role
    "assistant" is renamed to "model" and a missing role defaults to "user".
    The prompt is appended last as a user turn.
    """
    def to_gemini_role(entry):
        original_role = entry.get("role", "user")
        # Gemini uses 'model' instead of 'assistant'
        return "model" if original_role == "assistant" else original_role

    history_turns = [
        {
            "role": to_gemini_role(entry),
            "parts": [{"text": entry.get("content", "")}],
        }
        for entry in chat_history
    ]
    prompt_turn = {"role": "user", "parts": [{"text": prompt}]}
    return history_turns + [prompt_turn]
61
+
62
def generate_text(self, prompt: str, chat_history: list = None,
                  max_output_tokens: int = None, temperature: float = None):
    """Call ``generateContent`` with the history plus ``prompt``.

    Returns:
        A dict with ``model``, ``response``, ``tokens_generated``,
        ``total_duration_ms`` (always ``None``) and ``prompt_eval_tokens``;
        ``None`` on any failure or when the reply text is empty.
    """
    try:
        chat_history = chat_history or []
        clean_prompt = self.process_text(prompt)

        contents = self._build_contents(clean_prompt, chat_history)

        # ``is None`` instead of ``or``: an explicit temperature of 0.0 is a
        # valid request and must not be replaced by the default.
        if temperature is None:
            temperature = self.default_generation_temperature

        payload = {
            "contents": contents,
            "generationConfig": {
                "maxOutputTokens": int(max_output_tokens or self.default_generation_max_output_tokens),
                "temperature": float(temperature),
            }
        }

        # The key travels in the x-goog-api-key header, not in the query
        # string: request exceptions embed the URL in their message, and that
        # message is logged below.
        url = f"{self.url.rstrip('/')}/models/{self.model}:generateContent"
        headers = {
            "Content-Type": "application/json",
            "x-goog-api-key": self.api_key,
        }

        resp = requests.post(url, json=payload, headers=headers, timeout=6000)
        if resp.status_code != 200:
            self.logger.error("Gemini generate failed: %s %s", resp.status_code, resp.text)
            return None

        data = resp.json()

        try:
            generated_text = (
                data["candidates"][0]["content"]["parts"][0]["text"].strip()
            )
        except (KeyError, IndexError, TypeError):
            self.logger.error("Unexpected Gemini response structure: %s", data)
            return None

        if not generated_text:
            return None

        usage = data.get("usageMetadata", {})
        return {
            "model": self.model,
            "response": generated_text,
            "tokens_generated": usage.get("candidatesTokenCount"),
            "total_duration_ms": None,
            "prompt_eval_tokens": usage.get("promptTokenCount"),
        }

    except Exception as e:
        self.logger.exception("Error in GeminiProvider.generate_text: %s", e)
        return None
114
+
115
def embed_text(self, text: str, document_type: str = None):
    """Return the embedding vector of one text via ``embedContent``.

    ``document_type`` is mapped to a Gemini task type; unknown or missing
    values use ``RETRIEVAL_DOCUMENT``. The requested dimensionality is the
    configured ``embedding_size``, or 768 when none was set.

    Returns:
        The vector, ``[]`` when ``text`` is empty after stripping, or
        ``None`` when no model is set or the request fails.
    """
    try:
        if not self.embedding_model:
            self.logger.error("Embedding model is not set before calling embed_text()")
            return None

        clean_text = self.process_text(text)
        self.logger.debug("Cleaned text: '%s...'", clean_text[:20])
        if not clean_text:
            return []

        task_type_map = {
            "search_document": "RETRIEVAL_DOCUMENT",
            "search_query": "RETRIEVAL_QUERY",
            "classification": "CLASSIFICATION",
            "clustering": "CLUSTERING",
        }
        task_type = task_type_map.get(document_type, "RETRIEVAL_DOCUMENT")

        payload = {
            "model": f"models/{self.embedding_model}",
            "content": {"parts": [{"text": clean_text}]},
            # Honour the size given to set_embedding_model(); 768 stays the
            # fallback so an unconfigured provider behaves as before.
            "output_dimensionality": self.embedding_size or 768,
            "taskType": task_type,
        }

        # Key in the header rather than the URL, so that logged request
        # exceptions (which contain the URL) never expose it.
        url = f"{self.url.rstrip('/')}/models/{self.embedding_model}:embedContent"
        headers = {
            "Content-Type": "application/json",
            "x-goog-api-key": self.api_key,
        }

        resp = requests.post(url, json=payload, headers=headers, timeout=200)
        if resp.status_code != 200:
            self.logger.error("Gemini embedding failed: %s %s", resp.status_code, resp.text)
            return None

        data = resp.json()

        try:
            embedding = data["embedding"]["values"]
        except (KeyError, TypeError):
            self.logger.warning("'embedding' key not found in response JSON")
            return None

        self.logger.debug("Embedding length: %d", len(embedding))
        return embedding

    except Exception as e:
        self.logger.exception("Error in GeminiProvider.embed_text: %s", e)
        return None
166
+
167
def construct_prompt(self, prompt: str, role: str):
    """Build a ``{"role", "content"}`` message; the content is cleaned by ``process_text``."""
    cleaned_prompt = self.process_text(prompt)
    message = {"role": role}
    message["content"] = cleaned_prompt
    return message
172
+
173
def embed_text_batch(self, texts: list[str], batch_size: int = 32):
    """Embed many texts with the Gemini ``batchEmbedContents`` endpoint.

    Texts are cleaned with ``process_text`` and sent in slices of
    ``batch_size``. Texts that are empty after cleaning are skipped, so the
    result can be shorter than ``texts``. Consecutive requests are spaced by
    a 5 second pause.

    Args:
        texts: Texts to embed.
        batch_size: Maximum number of texts per request.

    Returns:
        A list of embedding vectors in request order, or ``None`` when no
        embedding model is configured or any request fails.
    """
    texts = list(texts or [])
    self.logger.info("Embedding %d texts using batch_size=%s", len(texts), batch_size)

    if not self.embedding_model:
        self.logger.error("Embedding model not set")
        return None

    all_embeddings = []

    # The key travels in a header rather than in the query string so it
    # cannot end up in logged URLs or in exception messages.
    url = f"{self.url.rstrip('/')}/models/{self.embedding_model}:batchEmbedContents"
    headers = {
        "Content-Type": "application/json",
        "x-goog-api-key": self.api_key,
    }

    request_sent = False
    for i in range(0, len(texts), batch_size):
        cleaned = (self.process_text(t) for t in texts[i:i + batch_size])
        clean_batch = [c for c in cleaned if c]
        if not clean_batch:
            # Nothing embeddable in this slice; an empty request would be
            # rejected and abort the whole run.
            continue

        if request_sent:
            # Throttle between requests only; no need to wait before the first.
            time.sleep(5)
        request_sent = True

        self.logger.debug(
            "Embedding texts %d-%d of %d", i + 1, min(i + batch_size, len(texts)), len(texts)
        )

        # Gemini batchEmbedContents takes a list of requests
        payload = {
            "requests": [
                {
                    "model": f"models/{self.embedding_model}",
                    "content": {"parts": [{"text": t}]},
                    "taskType": "RETRIEVAL_DOCUMENT",
                    "output_dimensionality": 768,
                }
                for t in clean_batch
            ]
        }

        try:
            resp = requests.post(url, json=payload, headers=headers, timeout=200)
        except requests.RequestException as exc:
            # Same failure contract as embed_text(): report and return None.
            self.logger.error("Gemini embedding request failed: %s", exc)
            return None

        if resp.status_code != 200:
            self.logger.error("Gemini embedding failed: %s %s", resp.status_code, resp.text)
            return None

        try:
            embeddings = [item["values"] for item in resp.json()["embeddings"]]
        except (KeyError, TypeError, ValueError):
            self.logger.error("No embeddings returned from Gemini")
            return None

        if not embeddings:
            self.logger.error("No embeddings returned from Gemini")
            return None

        self.logger.debug("Received %d embeddings", len(embeddings))
        all_embeddings.extend(embeddings)

    self.logger.info("Total embeddings created: %d", len(all_embeddings))
    return all_embeddings
229
+
230
def clean_content(self, text: str) -> str:
    """Strip markdown links and bracketed spans, then collapse blank lines.

    Applied in order: ``[label](target)`` links are removed, any remaining
    ``[...]`` span is removed, runs of newlines become a single newline, and
    the result is stripped of surrounding whitespace.
    """
    substitutions = (
        (r'\[.*?\]\(.*?\)', ''),
        (r'\[[^\]]*\]', ''),
        (r'\n+', '\n'),
    )
    cleaned = text
    for pattern, replacement in substitutions:
        cleaned = re.sub(pattern, replacement, cleaned)
    return cleaned.strip()
235
+
236
def web_search(self, query: str):
    """Answer ``query`` using Gemini with Google Search grounding enabled.

    Calls ``generateContent`` with the ``google_search`` tool and collects
    the answer text plus any source URLs.

    Args:
        query: The user question to search for.

    Returns:
        A dict with ``"text"`` (cleaned answer text, possibly empty) and
        ``"references"`` (list of URLs taken from the grounding metadata and
        from bare URLs in the answer). On a non-200 response the text is a
        "no results" notice; on any exception it is an error message. In both
        cases ``"references"`` is empty.
    """
    try:
        payload = {
            "contents": [{"role": "user", "parts": [{"text": query}]}],
            "tools": [{"google_search": {}}],
            "generationConfig": {
                "maxOutputTokens": int(self.default_generation_max_output_tokens),
                "temperature": float(self.default_generation_temperature),
            }
        }

        # The key travels in a header rather than in the query string: the
        # error text returned below embeds the exception message, which for
        # connection errors contains the request URL.
        url = f"{self.url.rstrip('/')}/models/{self.model}:generateContent"
        headers = {
            "Content-Type": "application/json",
            "x-goog-api-key": self.api_key,
        }

        resp = requests.post(url, json=payload, headers=headers, timeout=6000)

        if resp.status_code != 200:
            self.logger.error("Gemini web search failed: %s %s", resp.status_code, resp.text)
            return {
                "text": "No relevant external results found.",
                "references": []
            }

        data = resp.json()

        combined_text = []
        references = set()

        try:
            candidate = data["candidates"][0]
        except (KeyError, IndexError, TypeError):
            candidate = {}

        # A grounded answer may be split across several parts; keep them all.
        try:
            parts = candidate["content"]["parts"]
            answer = "".join(
                part.get("text") or "" for part in parts if isinstance(part, dict)
            )
            if answer:
                combined_text.append(self.clean_content(answer))
        except (KeyError, IndexError, TypeError, AttributeError):
            pass

        # Extract grounding metadata URLs
        try:
            metadata = candidate.get("groundingMetadata") or {}
            for chunk in metadata.get("groundingChunks") or []:
                uri = ((chunk or {}).get("web") or {}).get("uri") or ""
                if uri.startswith("http"):
                    references.add(uri)
        except (KeyError, IndexError, TypeError, AttributeError):
            pass

        # Also scan response text for bare URLs
        references.update(re.findall(r"https?://[^\s)]+", "\n".join(combined_text)))

        return {
            "text": "\n\n".join(combined_text[:3]),
            # Sorted so repeated calls yield a stable order.
            "references": sorted(references)
        }

    except Exception as e:
        self.logger.exception("Gemini web search failed: %s", e)
        return {
            "text": f"Gemini search error: {str(e)}",
            "references": []
        }
stores/llm/providers/GroqProvider.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from stores.llm.LLMInterface import LLMInterface
2
+ import logging
3
+ import requests
4
+ import re
5
+ import os
6
+
7
+
8
class GroqProvider(LLMInterface):
    """LLM provider for Groq's OpenAI-compatible chat-completions API.

    Only text generation talks to the API. The embedding and web-search
    methods log a warning and return a "not supported" result.
    """

    def __init__(self, url: str = None, model: str = None,
                 default_input_max_characters: int = 1000,
                 default_generation_max_output_tokens: int = 1000,
                 default_generation_temperature: float = 0.1, api_key: str = None):
        """Store connection settings and generation defaults.

        ``url`` defaults to the public Groq endpoint and ``api_key`` falls
        back to the ``GROQ_API_KEY`` environment variable.
        """
        self.url = url or "https://api.groq.com/openai/v1"
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        self.model = model
        self.generation_model_id = None

        self.embedding_model = None
        self.embedding_model_id = None
        self.embedding_size = None

        self.default_input_max_characters = default_input_max_characters
        self.default_generation_max_output_tokens = default_generation_max_output_tokens
        self.default_generation_temperature = default_generation_temperature
        self.logger = logging.getLogger(__name__)

    def set_generation_model(self, model_id: str):
        """Select the generation model; a falsy ``model_id`` is ignored."""
        if model_id:
            self.model = model_id
            # Kept in step with ``model`` the same way set_embedding_model()
            # keeps ``embedding_model_id`` in step with ``embedding_model``.
            self.generation_model_id = model_id

    def set_embedding_model(self, model_id: str, embedding_size: int):
        """Record the embedding model and vector size; a falsy ``model_id`` is ignored."""
        if model_id:
            self.embedding_model = model_id
            self.embedding_size = embedding_size
            self.embedding_model_id = model_id

    def process_text(self, text: str):
        """Return ``text`` as a stripped string, or ``""`` for falsy input."""
        if not text:
            return ""
        return str(text).strip()

    def generate_text(self, prompt: str, chat_history: list = None,
                      max_output_tokens: int = None, temperature: float = None):
        """Generate a completion for ``prompt`` via ``/chat/completions``.

        Args:
            prompt: User prompt; appended as the last ``user`` message.
            chat_history: Optional list of ``{"role", "content"}`` dicts sent
                before the prompt.
            max_output_tokens: Token limit; falls back to the instance default.
            temperature: Sampling temperature; ``None`` selects the instance
                default, while an explicit ``0`` is honoured.

        Returns:
            A dict with ``model``, ``response``, ``tokens_generated``,
            ``total_duration_ms`` and ``prompt_eval_tokens``, or ``None`` on
            any failure or empty completion.
        """
        try:
            chat_history = chat_history or []
            clean_prompt = self.process_text(prompt)

            messages = [
                {
                    "role": entry.get("role", "user"),
                    "content": entry.get("content", ""),
                }
                for entry in chat_history
            ]
            messages.append({"role": "user", "content": clean_prompt})

            # ``temperature or default`` would silently discard an explicit 0.0.
            if temperature is None:
                temperature = self.default_generation_temperature

            payload = {
                "model": self.model,
                "messages": messages,
                "max_tokens": int(max_output_tokens or self.default_generation_max_output_tokens),
                "temperature": float(temperature),
            }

            url = self.url.rstrip("/") + "/chat/completions"
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            }

            resp = requests.post(url, json=payload, headers=headers, timeout=6000)
            if resp.status_code != 200:
                self.logger.error("Groq generate failed: %s %s", resp.status_code, resp.text)
                return None

            data = resp.json()

            try:
                generated_text = data["choices"][0]["message"]["content"].strip()
            except (KeyError, IndexError, TypeError, AttributeError):
                # AttributeError covers a ``content`` of null.
                self.logger.error("Unexpected Groq response structure: %s", data)
                return None

            if not generated_text:
                return None

            usage = data.get("usage") or {}
            # Groq exposes x_groq.usage.total_time in seconds
            total_time_ms = None
            try:
                total_time_ms = round(data["x_groq"]["usage"]["total_time"] * 1000, 2)
            except (KeyError, TypeError):
                pass

            return {
                "model": data.get("model"),
                "response": generated_text,
                "tokens_generated": usage.get("completion_tokens"),
                "total_duration_ms": total_time_ms,
                "prompt_eval_tokens": usage.get("prompt_tokens"),
            }

        except Exception as e:
            self.logger.exception("Error in GroqProvider.generate_text: %s", e)
            return None

    def embed_text(self, text: str, document_type: str = None):
        """Groq does not support embeddings — returns None."""
        self.logger.warning("GroqProvider does not support embeddings.")
        return None

    def construct_prompt(self, prompt: str, role: str):
        """Return a ``{"role", "content"}`` message with the cleaned prompt."""
        return {
            "role": role,
            "content": self.process_text(prompt)
        }

    def embed_text_batch(self, texts: list[str], batch_size: int = 32):
        """Groq does not support embeddings — returns None."""
        self.logger.warning("GroqProvider does not support embeddings.")
        return None

    def clean_content(self, text: str) -> str:
        """Remove markdown links and bracketed spans, then collapse newlines."""
        text = re.sub(r'\[.*?\]\(.*?\)', '', text)
        text = re.sub(r'\[[^\]]*\]', '', text)
        text = re.sub(r'\n+', '\n', text).strip()
        return text

    def web_search(self, query: str):
        """Groq has no native web search — returns a not-supported notice."""
        self.logger.warning("GroqProvider.web_search is not natively supported.")
        return {
            "text": "Web search is not natively supported by the Groq API.",
            "references": []
        }
stores/llm/providers/HuggingFaceProvider.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from stores.llm.LLMInterface import LLMInterface
2
+ import logging
3
+ import requests
4
+ import re
5
+ import os
6
+
7
+
8
class HuggingFaceProvider(LLMInterface):
    """LLM provider for the Hugging Face router.

    Generation uses the OpenAI-compatible ``/v1/chat/completions`` route and
    embeddings use the ``hf-inference`` feature-extraction pipeline route.
    Both are built on top of ``self.url``.
    """

    def __init__(self, url: str = None, model: str = None,
                 default_input_max_characters: int = 1000,
                 default_generation_max_output_tokens: int = 1000,
                 default_generation_temperature: float = 0.1, api_key: str = None):
        """Store connection settings and generation defaults.

        ``url`` defaults to ``https://router.huggingface.co`` and ``api_key``
        falls back to the ``HF_API_KEY`` environment variable.
        """
        self.url = url or "https://router.huggingface.co"
        self.api_key = api_key or os.getenv("HF_API_KEY")
        self.model = model
        self.generation_model_id = None

        self.embedding_model = None
        self.embedding_model_id = None
        self.embedding_size = None

        self.default_input_max_characters = default_input_max_characters
        self.default_generation_max_output_tokens = default_generation_max_output_tokens
        self.default_generation_temperature = default_generation_temperature
        self.logger = logging.getLogger(__name__)

    def _auth_headers(self):
        """Return the JSON + bearer-token headers shared by every request."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _embedding_url(self):
        """Return the feature-extraction URL for the configured embedding model."""
        # Built from self.url (like generate_text) so a custom base URL is
        # honoured; with the default base this is the previously hard-coded URL.
        return (
            f"{self.url.rstrip('/')}/hf-inference/models/"
            f"{self.embedding_model}/pipeline/feature-extraction"
        )

    def set_generation_model(self, model_id: str):
        """Select the generation model; a falsy ``model_id`` is ignored."""
        if model_id:
            self.model = model_id
            # Kept in step with ``model`` the same way set_embedding_model()
            # keeps ``embedding_model_id`` in step with ``embedding_model``.
            self.generation_model_id = model_id

    def set_embedding_model(self, model_id: str, embedding_size: int):
        """Record the embedding model and vector size; a falsy ``model_id`` is ignored."""
        if model_id:
            self.embedding_model = model_id
            self.embedding_size = embedding_size
            self.embedding_model_id = model_id

    def process_text(self, text: str):
        """Return ``text`` as a stripped string, or ``""`` for falsy input."""
        if not text:
            return ""
        return str(text).strip()

    def generate_text(self, prompt: str, chat_history: list = None,
                      max_output_tokens: int = None, temperature: float = None):
        """Generate a completion for ``prompt`` via ``/v1/chat/completions``.

        Args:
            prompt: User prompt; appended as the last ``user`` message.
            chat_history: Optional list of ``{"role", "content"}`` dicts sent
                before the prompt.
            max_output_tokens: Token limit; falls back to the instance default.
            temperature: Sampling temperature; ``None`` selects the instance
                default, while an explicit ``0`` is honoured.

        Returns:
            A dict with ``model``, ``response``, ``tokens_generated``,
            ``total_duration_ms`` (always ``None``) and ``prompt_eval_tokens``,
            or ``None`` on any failure or empty completion.
        """
        try:
            chat_history = chat_history or []
            clean_prompt = self.process_text(prompt)

            messages = [
                {
                    "role": entry.get("role", "user"),
                    "content": entry.get("content", ""),
                }
                for entry in chat_history
            ]
            messages.append({"role": "user", "content": clean_prompt})

            # ``temperature or default`` would silently discard an explicit 0.0.
            if temperature is None:
                temperature = self.default_generation_temperature

            payload = {
                "model": self.model,
                "messages": messages,
                "max_tokens": int(max_output_tokens or self.default_generation_max_output_tokens),
                "temperature": float(temperature),
            }

            # HF Inference API (serverless): /v1/chat/completions (OpenAI-compatible)
            url = self.url.rstrip("/") + "/v1/chat/completions"

            resp = requests.post(url, json=payload, headers=self._auth_headers(), timeout=6000)
            if resp.status_code != 200:
                self.logger.error("HuggingFace generate failed: %s %s", resp.status_code, resp.text)
                return None

            data = resp.json()

            try:
                generated_text = data["choices"][0]["message"]["content"].strip()
            except (KeyError, IndexError, TypeError, AttributeError):
                # AttributeError covers a ``content`` of null.
                self.logger.error("Unexpected HuggingFace response structure: %s", data)
                return None

            if not generated_text:
                return None

            usage = data.get("usage") or {}
            return {
                "model": data.get("model"),
                "response": generated_text,
                "tokens_generated": usage.get("completion_tokens"),
                "total_duration_ms": None,
                "prompt_eval_tokens": usage.get("prompt_tokens"),
            }

        except Exception as e:
            self.logger.exception("Error in HuggingFaceProvider.generate_text: %s", e)
            return None

    def embed_text(self, text: str, document_type: str = None):
        """Embed a single text through the feature-extraction pipeline.

        ``document_type`` is accepted for interface compatibility and unused.

        Returns:
            The embedding vector, ``[]`` when the cleaned text is empty, or
            ``None`` when no model is configured, the request fails, or the
            response has an unrecognised shape.
        """
        try:
            if not self.embedding_model:
                self.logger.error("Embedding model is not set before calling embed_text()")
                return None

            clean_text = self.process_text(text)
            if not clean_text:
                return []
            self.logger.debug("Cleaned text: '%s...'", clean_text[:20])

            resp = requests.post(
                self._embedding_url(),
                json={"inputs": clean_text},
                headers=self._auth_headers(),
                timeout=200,
            )
            if resp.status_code != 200:
                self.logger.error("HuggingFace embedding failed: %s %s", resp.status_code, resp.text)
                return None

            data = resp.json()

            # Accepted shapes: [[float, ...]], [float, ...] or {"embedding": [...]}
            embedding = None
            if isinstance(data, list):
                if data and isinstance(data[0], list):
                    embedding = data[0]
                elif data and isinstance(data[0], (int, float)):
                    # JSON integers (e.g. 0) decode as int, not float.
                    embedding = data
            elif isinstance(data, dict) and "embedding" in data:
                embedding = data["embedding"]

            if embedding is None:
                self.logger.warning("'embedding' key not found in response JSON")
                return None

            self.logger.debug("Embedding length: %d", len(embedding))
            return embedding

        except Exception as e:
            self.logger.exception("Error in HuggingFaceProvider.embed_text: %s", e)
            return None

    def construct_prompt(self, prompt: str, role: str):
        """Return a ``{"role", "content"}`` message with the cleaned prompt."""
        return {
            "role": role,
            "content": self.process_text(prompt)
        }

    def embed_text_batch(self, texts: list[str], batch_size: int = 32):
        """Embed many texts in slices of ``batch_size``.

        Texts that are empty after cleaning are skipped, so the result can be
        shorter than ``texts``.

        Returns:
            A list of embedding vectors in request order, or ``None`` when no
            model is configured or any request fails.
        """
        texts = list(texts or [])
        self.logger.info("Embedding %d texts using batch_size=%s", len(texts), batch_size)

        if not self.embedding_model:
            self.logger.error("Embedding model not set")
            return None

        all_embeddings = []
        url = self._embedding_url()
        headers = self._auth_headers()

        for i in range(0, len(texts), batch_size):
            cleaned = (self.process_text(t) for t in texts[i:i + batch_size])
            clean_batch = [c for c in cleaned if c]
            if not clean_batch:
                # Nothing embeddable in this slice; skip the request.
                continue

            self.logger.debug(
                "Embedding texts %d-%d of %d", i + 1, min(i + batch_size, len(texts)), len(texts)
            )

            try:
                resp = requests.post(url, json={"inputs": clean_batch}, headers=headers, timeout=200)
            except requests.RequestException as exc:
                # Same failure contract as embed_text(): report and return None.
                self.logger.error("HuggingFace embedding request failed: %s", exc)
                return None

            if resp.status_code != 200:
                self.logger.error("HuggingFace embedding failed: %s %s", resp.status_code, resp.text)
                return None

            try:
                data = resp.json()
            except ValueError:
                self.logger.error("No embeddings returned from HuggingFace")
                return None

            # Batch response: [[f, f, ...], [f, f, ...]]; a lone vector may come back flat.
            embeddings = None
            if isinstance(data, list) and data:
                if isinstance(data[0], list):
                    embeddings = data
                elif isinstance(data[0], (int, float)):
                    embeddings = [data]

            if not embeddings:
                self.logger.error("No embeddings returned from HuggingFace")
                return None

            self.logger.debug("Received %d embeddings", len(embeddings))
            all_embeddings.extend(embeddings)

        self.logger.info("Total embeddings created: %d", len(all_embeddings))
        return all_embeddings

    def clean_content(self, text: str) -> str:
        """Remove markdown links and bracketed spans, then collapse newlines."""
        text = re.sub(r'\[.*?\]\(.*?\)', '', text)
        text = re.sub(r'\[[^\]]*\]', '', text)
        text = re.sub(r'\n+', '\n', text).strip()
        return text

    def web_search(self, query: str):
        """HuggingFace Inference API has no native web search — returns a not-supported notice."""
        self.logger.warning("HuggingFaceProvider.web_search is not natively supported.")
        return {
            "text": "Web search is not natively supported by the HuggingFace Inference API.",
            "references": []
        }
stores/llm/providers/MistralProvider.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ from stores.llm.LLMInterface import LLMInterface
4
+ import logging
5
+ import requests
6
+ import re
7
+ import os
8
+
9
+
10
class MistralProvider(LLMInterface):
    """LLM provider for the Mistral REST API (chat completions and embeddings)."""

    def __init__(self, url: str = None, model: str = None,
                 default_input_max_characters: int = 1000,
                 default_generation_max_output_tokens: int = 1000,
                 default_generation_temperature: float = 0.1, api_key: str = None):
        """Store connection settings and generation defaults.

        ``url`` defaults to ``https://api.mistral.ai/v1`` and ``api_key``
        falls back to the ``MISTRAL_API_KEY`` environment variable.
        """
        self.url = url or "https://api.mistral.ai/v1"
        self.api_key = api_key or os.getenv("MISTRAL_API_KEY")
        self.model = model
        self.generation_model_id = None

        self.embedding_model = None
        self.embedding_model_id = None
        self.embedding_size = None

        self.default_input_max_characters = default_input_max_characters
        self.default_generation_max_output_tokens = default_generation_max_output_tokens
        self.default_generation_temperature = default_generation_temperature
        self.logger = logging.getLogger(__name__)

    def _auth_headers(self):
        """Return the JSON + bearer-token headers shared by every request."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def set_generation_model(self, model_id: str):
        """Select the generation model; a falsy ``model_id`` is ignored."""
        if model_id:
            self.model = model_id
            # Kept in step with ``model`` the same way set_embedding_model()
            # keeps ``embedding_model_id`` in step with ``embedding_model``.
            self.generation_model_id = model_id

    def set_embedding_model(self, model_id: str, embedding_size: int):
        """Record the embedding model and vector size; a falsy ``model_id`` is ignored."""
        if model_id:
            self.embedding_model = model_id
            self.embedding_size = embedding_size
            self.embedding_model_id = model_id

    def process_text(self, text: str):
        """Return ``text`` as a stripped string, or ``""`` for falsy input."""
        if not text:
            return ""
        return str(text).strip()

    def generate_text(self, prompt: str, chat_history: list = None,
                      max_output_tokens: int = None, temperature: float = None):
        """Generate a completion for ``prompt`` via ``/chat/completions``.

        Args:
            prompt: User prompt; appended as the last ``user`` message.
            chat_history: Optional list of ``{"role", "content"}`` dicts sent
                before the prompt.
            max_output_tokens: Token limit; falls back to the instance default.
            temperature: Sampling temperature; ``None`` selects the instance
                default, while an explicit ``0`` is honoured.

        Returns:
            A dict with ``model``, ``response``, ``tokens_generated``,
            ``total_duration_ms`` (always ``None``) and ``prompt_eval_tokens``,
            or ``None`` on any failure or empty completion.
        """
        try:
            chat_history = chat_history or []
            clean_prompt = self.process_text(prompt)

            messages = [
                {
                    "role": entry.get("role", "user"),
                    "content": entry.get("content", ""),
                }
                for entry in chat_history
            ]
            messages.append({"role": "user", "content": clean_prompt})

            # ``temperature or default`` would silently discard an explicit 0.0.
            if temperature is None:
                temperature = self.default_generation_temperature

            payload = {
                "model": self.model,
                "messages": messages,
                "max_tokens": int(max_output_tokens or self.default_generation_max_output_tokens),
                "temperature": float(temperature),
            }

            url = self.url.rstrip("/") + "/chat/completions"

            resp = requests.post(url, json=payload, headers=self._auth_headers(), timeout=6000)
            if resp.status_code != 200:
                self.logger.error("Mistral generate failed: %s %s", resp.status_code, resp.text)
                return None

            data = resp.json()

            try:
                generated_text = data["choices"][0]["message"]["content"].strip()
            except (KeyError, IndexError, TypeError, AttributeError):
                # AttributeError covers a ``content`` of null.
                self.logger.error("Unexpected Mistral response structure: %s", data)
                return None

            if not generated_text:
                return None

            usage = data.get("usage") or {}
            return {
                "model": data.get("model"),
                "response": generated_text,
                "tokens_generated": usage.get("completion_tokens"),
                "total_duration_ms": None,
                "prompt_eval_tokens": usage.get("prompt_tokens"),
            }

        except Exception as e:
            self.logger.exception("Error in MistralProvider.generate_text: %s", e)
            return None

    def embed_text(self, text: str, document_type: str = None):
        """Embed a single text via ``/embeddings``.

        ``document_type`` is accepted for interface compatibility and unused.

        Returns:
            The embedding vector, ``[]`` when the cleaned text is empty, or
            ``None`` when no model is configured, the request fails, or the
            response carries no embedding.
        """
        try:
            if not self.embedding_model:
                self.logger.error("Embedding model is not set before calling embed_text()")
                return None

            clean_text = self.process_text(text)
            if not clean_text:
                return []
            self.logger.debug("Cleaned text: '%s...'", clean_text[:20])

            payload = {
                "model": self.embedding_model,
                "input": [clean_text],
            }
            url = self.url.rstrip("/") + "/embeddings"

            resp = requests.post(url, json=payload, headers=self._auth_headers(), timeout=200)
            if resp.status_code != 200:
                self.logger.error("Mistral embedding failed: %s %s", resp.status_code, resp.text)
                return None

            data = resp.json()

            try:
                embedding = data["data"][0]["embedding"]
            except (KeyError, IndexError, TypeError):
                self.logger.warning("'embedding' key not found in response JSON")
                return None

            self.logger.debug("Embedding length: %d", len(embedding))
            return embedding

        except Exception as e:
            self.logger.exception("Error in MistralProvider.embed_text: %s", e)
            return None

    def construct_prompt(self, prompt: str, role: str):
        """Return a ``{"role", "content"}`` message with the cleaned prompt."""
        return {
            "role": role,
            "content": self.process_text(prompt)
        }

    def embed_text_batch(self, texts: list[str], batch_size: int = 32):
        """Embed many texts via ``/embeddings`` in slices of ``batch_size``.

        Texts that are empty after cleaning are skipped, so the result can be
        shorter than ``texts``. Consecutive requests are spaced by a 5 second
        pause.

        Returns:
            A list of embedding vectors in request order, or ``None`` when no
            model is configured or any request fails.
        """
        texts = list(texts or [])
        self.logger.info("Embedding %d texts using batch_size=%s", len(texts), batch_size)

        if not self.embedding_model:
            self.logger.error("Embedding model not set")
            return None

        all_embeddings = []
        url = self.url.rstrip("/") + "/embeddings"
        headers = self._auth_headers()

        request_sent = False
        for i in range(0, len(texts), batch_size):
            cleaned = (self.process_text(t) for t in texts[i:i + batch_size])
            clean_batch = [c for c in cleaned if c]
            if not clean_batch:
                # Nothing embeddable in this slice; skip the request.
                continue

            if request_sent:
                # Throttle between requests only; no need to wait before the first.
                time.sleep(5)
            request_sent = True

            self.logger.debug(
                "Embedding texts %d-%d of %d", i + 1, min(i + batch_size, len(texts)), len(texts)
            )

            payload = {
                "model": self.embedding_model,
                "input": clean_batch,
            }

            try:
                resp = requests.post(url, json=payload, headers=headers, timeout=200)
            except requests.RequestException as exc:
                # Same failure contract as embed_text(): report and return None.
                self.logger.error("Mistral embedding request failed: %s", exc)
                return None

            if resp.status_code != 200:
                self.logger.error("Mistral embedding failed: %s %s", resp.status_code, resp.text)
                return None

            try:
                embeddings = [item["embedding"] for item in resp.json()["data"]]
            except (KeyError, TypeError, ValueError):
                self.logger.error("No embeddings returned from Mistral")
                return None

            if not embeddings:
                self.logger.error("No embeddings returned from Mistral")
                return None

            self.logger.debug("Received %d embeddings", len(embeddings))
            all_embeddings.extend(embeddings)

        self.logger.info("Total embeddings created: %d", len(all_embeddings))
        return all_embeddings

    def clean_content(self, text: str) -> str:
        """Remove markdown links and bracketed spans, then collapse newlines."""
        text = re.sub(r'\[.*?\]\(.*?\)', '', text)
        text = re.sub(r'\[[^\]]*\]', '', text)
        text = re.sub(r'\n+', '\n', text).strip()
        return text

    def web_search(self, query: str):
        """Mistral has no native web search — returns a not-supported notice."""
        self.logger.warning("MistralProvider.web_search is not natively supported.")
        return {
            "text": "Web search is not natively supported by the Mistral API.",
            "references": []
        }
stores/llm/providers/OllamaProvider.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from stores.llm.LLMInterface import LLMInterface
2
+ import logging
3
+ import requests
4
+ import re
5
+ import ollama
6
+ import os
7
class OllamaProvider(LLMInterface):
    """LLM provider for an Ollama server (generation, embeddings, web search)."""

    def __init__(self, url: str = None, model: str = None,
                 default_input_max_characters: int = 1000,
                 default_generation_max_output_tokens: int = 1000,
                 default_generation_temperature: float = 0.1, api_key: str = None):
        """Store connection settings and generation defaults.

        ``url`` defaults to ``http://localhost:11434`` and ``api_key`` falls
        back to the ``OLLAMA_API_KEY`` environment variable.
        """
        self.url = url or "http://localhost:11434"
        self.api_key = api_key or os.getenv("OLLAMA_API_KEY")
        self.model = model
        self.generation_model_id = None

        self.embedding_model = None
        self.embedding_model_id = None
        self.embedding_size = None

        self.default_input_max_characters = default_input_max_characters
        self.default_generation_max_output_tokens = default_generation_max_output_tokens
        self.default_generation_temperature = default_generation_temperature
        self.logger = logging.getLogger(__name__)

    def _auth_headers(self):
        """Return the bearer-token header when an API key is configured."""
        if self.api_key:
            return {"Authorization": f"Bearer {self.api_key}"}
        return {}

    def set_generation_model(self, model_id: str):
        """Select the generation model; a falsy ``model_id`` is ignored."""
        if model_id:
            self.model = model_id
            # Kept in step with ``model`` the same way set_embedding_model()
            # keeps ``embedding_model_id`` in step with ``embedding_model``.
            self.generation_model_id = model_id

    def set_embedding_model(self, model_id: str, embedding_size: int):
        """Record the embedding model and vector size; a falsy ``model_id`` is ignored."""
        if model_id:
            self.embedding_model = model_id
            self.embedding_size = embedding_size
            self.embedding_model_id = model_id

    def process_text(self, text: str):
        """Return ``text`` as a stripped string, or ``""`` for falsy input."""
        if not text:
            return ""
        return str(text).strip()

    def generate_text(self, prompt: str, chat_history: list = None,
                      max_output_tokens: int = None, temperature: float = None):
        """Generate a completion for ``prompt`` via ``/api/generate``.

        Args:
            prompt: Prompt text sent to the model.
            chat_history: Accepted for interface compatibility; it is not
                sent to ``/api/generate``.
            max_output_tokens: Token limit; falls back to the instance default.
            temperature: Sampling temperature; ``None`` selects the instance
                default, while an explicit ``0`` is honoured.

        Returns:
            A dict with ``model``, ``response``, ``tokens_generated``,
            ``total_duration_ms`` and ``prompt_eval_tokens``, or ``None`` on
            any failure or empty completion.
        """
        try:
            clean_prompt = self.process_text(prompt)

            # ``temperature or default`` would silently discard an explicit 0.0.
            if temperature is None:
                temperature = self.default_generation_temperature

            payload = {
                "model": self.model,
                "prompt": clean_prompt,
                "stream": False,
                # Model parameters must be nested under "options"; as
                # top-level keys they are ignored by /api/generate.
                "options": {
                    "num_predict": int(max_output_tokens or self.default_generation_max_output_tokens),
                    "temperature": float(temperature),
                },
            }

            url = self.url.rstrip("/") + "/api/generate"
            resp = requests.post(url, json=payload, headers=self._auth_headers(), timeout=6000)
            if resp.status_code != 200:
                self.logger.error("Ollama generate failed: %s %s", resp.status_code, resp.text)
                return None

            data = resp.json()

            generated_text = (data.get("response") or "").strip()

            # If nothing generated, treat as failure
            if not generated_text:
                return None

            return {
                "model": data.get("model"),
                "response": generated_text,
                "tokens_generated": data.get("eval_count"),
                # total_duration is divided by 1e6 to report milliseconds
                "total_duration_ms": round((data.get("total_duration") or 0) / 1e6, 2),
                "prompt_eval_tokens": data.get("prompt_eval_count"),
            }

        except Exception as e:
            self.logger.exception("Error in OllamaProvider.generate_text: %s", e)
            return None

    def embed_text(self, text: str, document_type: str = None):
        """Return an embedding vector from Ollama's ``/api/embed``.

        ``document_type`` is accepted for interface compatibility and unused.

        Returns:
            The embedding vector, ``[]`` when the cleaned text is empty, or
            ``None`` when no model is configured, the request fails, or the
            response carries no embedding.
        """
        try:
            if not self.embedding_model:
                self.logger.error("Embedding model is not set before calling embed_text()")
                return None

            clean_text = self.process_text(text)
            if not clean_text:
                return []
            self.logger.debug("Cleaned text: '%s...'", clean_text[:20])

            payload = {
                "model": self.embedding_model,
                "input": clean_text
            }

            url = self.url.rstrip("/") + "/api/embed"
            resp = requests.post(url, json=payload, headers=self._auth_headers(), timeout=400)
            if resp.status_code != 200:
                self.logger.error("Ollama embedding failed: %s %s", resp.status_code, resp.text)
                return None

            data = resp.json()

            # Either {"embedding": [...]} or {"embeddings": [[...]]}
            if "embedding" in data:
                self.logger.debug("Embedding length: %d", len(data["embedding"]))
                return data["embedding"]
            if "embeddings" in data:
                return data["embeddings"][0]

            self.logger.warning("'embedding' key not found in response JSON")
            return None

        except Exception as e:
            self.logger.exception("Error in OllamaProvider.embed_text: %s", e)
            return None

    def construct_prompt(self, prompt: str, role: str):
        """Return a ``{"role", "content"}`` message with the cleaned prompt."""
        return {
            "role": role,
            "content": self.process_text(prompt)
        }

    def _embed_via_endpoint(self, endpoint: str, texts: list, batch_size: int, headers: dict):
        """Embed every text through one endpoint; return None as soon as a request fails."""
        url = self.url.rstrip("/") + endpoint
        total_texts = len(texts)
        vectors = []

        for i in range(0, total_texts, batch_size):
            cleaned = (self.process_text(t) for t in texts[i:i + batch_size])
            clean_batch = [c for c in cleaned if c]
            if not clean_batch:
                continue

            if endpoint == "/api/embed":
                bodies = [{"model": self.embedding_model, "input": clean_batch}]
            else:
                # Legacy endpoint: send individually
                bodies = [{"model": self.embedding_model, "prompt": c} for c in clean_batch]

            for body in bodies:
                resp = requests.post(url, json=body, headers=headers, timeout=400)
                if resp.status_code != 200:
                    self.logger.warning(
                        "Batch embedding failed at %s: %s %s", endpoint, resp.status_code, resp.text
                    )
                    return None

                data = resp.json()
                if data.get("embeddings"):
                    vectors.extend(data["embeddings"])
                elif data.get("embedding"):
                    vectors.append(data["embedding"])
                else:
                    self.logger.warning("No embeddings in response from %s", endpoint)
                    return None

            self.logger.info(
                "Embedded %d/%d texts using %s", min(i + batch_size, total_texts), total_texts, endpoint
            )

        return vectors

    def embed_text_batch(self, texts: list[str], batch_size: int = 64):
        """
        Batch embedding for a list of texts, compatible with both /api/embed (new) and /api/embeddings (legacy).

        ``/api/embed`` is tried first. If any request to it fails, its partial
        results are discarded and the whole list is retried on
        ``/api/embeddings``, so results from the two endpoints are never mixed
        and a failed batch never leaves a silent gap. Texts that are empty
        after cleaning are skipped.

        Returns:
            A list of embedding vectors, or ``[]`` when no embedding model is
            configured or both endpoints fail.
        """
        texts = list(texts or [])

        if not self.embedding_model:
            self.logger.error("Embedding model not set")
            return []

        headers = {"Content-Type": "application/json"}
        headers.update(self._auth_headers())

        total_texts = len(texts)
        self.logger.info(
            "Starting batch embedding of %d texts with batch_size=%s", total_texts, batch_size
        )

        for endpoint in ("/api/embed", "/api/embeddings"):
            try:
                vectors = self._embed_via_endpoint(endpoint, texts, batch_size, headers)
            except Exception as e:
                self.logger.exception("Batch embedding error at %s: %s", endpoint, e)
                continue

            if vectors:
                self.logger.info(
                    "Finished embedding %d/%d texts successfully", len(vectors), total_texts
                )
                return vectors  # stop after successful endpoint

        return []

    def clean_content(self, text: str) -> str:
        """Remove markdown links and bracketed spans, then collapse newlines."""
        text = re.sub(r'\[.*?\]\(.*?\)', '', text)
        text = re.sub(r'\[[^\]]*\]', '', text)
        text = re.sub(r'\n+', '\n', text).strip()
        return text

    def web_search(self, query: str):
        """Use the Ollama client to perform a web search.

        Returns:
            A dict with ``"text"`` (the first three cleaned result bodies
            joined by blank lines) and ``"references"`` (URLs found in the
            results). When there are no results, no API key, or an error
            occurs, ``"text"`` carries a notice and ``"references"`` is empty.
        """
        try:
            api_key = self.api_key or os.getenv("OLLAMA_API_KEY")
            if not api_key:
                # Fail with a clear message instead of a str + None TypeError.
                self.logger.error("Ollama web search failed: OLLAMA_API_KEY is not set")
                return {
                    "text": "Ollama search error: OLLAMA_API_KEY is not set",
                    "references": []
                }

            ollama_client = ollama.Client(headers={'Authorization': 'Bearer ' + api_key})
            response = ollama_client.web_search(query)

            if not response or "results" not in response or len(response["results"]) == 0:
                return {
                    "text": "No relevant external results found.",
                    "references": []
                }

            combined_text = []
            references = set()

            for item in response["results"]:
                content = getattr(item, "content", None) or ""
                combined_text.append(self.clean_content(content))

                references.update(re.findall(r"https?://[^\s)]+", content))

                item_url = getattr(item, "url", None)
                if item_url:
                    references.add(item_url)

            return {
                "text": "\n\n".join(combined_text[:3]),
                # Sorted so repeated calls yield a stable order.
                "references": sorted(references)
            }

        except Exception as e:
            self.logger.exception("Ollama web search failed: %s", e)
            return {
                "text": f"Ollama search error: {str(e)}",
                "references": []
            }
stores/llm/providers/OpenAIProvider.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ..LLMInterface import LLMInterface
2
+ from ..LLMEnums import OpenAIEnums
3
+ from openai import OpenAI
4
+ import logging
5
+
6
class OpenAIProvider(LLMInterface):
    """LLM provider backed by the OpenAI Python client.

    Offers chat-completion text generation and embeddings. A model must be
    selected with ``set_generation_model`` / ``set_embedding_model`` before
    the matching call; until then those calls log an error and return
    ``None``.
    """

    def __init__(self, api_key: str, api_url: str = None,
                 default_input_max_characters: int = 1000,
                 default_generation_max_output_tokens: int = 1000,
                 default_generation_temperature: float = 0.1):
        """Store the defaults and create the OpenAI client.

        Args:
            api_key: API key handed to the OpenAI client.
            api_url: Optional base URL handed to the client as ``base_url``.
            default_input_max_characters: Maximum number of characters kept
                by ``process_text``.
            default_generation_max_output_tokens: Token limit used when
                ``generate_text`` is called without one.
            default_generation_temperature: Temperature used when
                ``generate_text`` is called without one.
        """
        self.api_key = api_key
        self.api_url = api_url
        self.default_input_max_characters = default_input_max_characters
        self.default_generation_max_output_tokens = default_generation_max_output_tokens
        self.default_generation_temperature = default_generation_temperature

        self.generation_model_id = None
        self.embedding_model_id = None
        self.embedding_size = None

        self.client = OpenAI(api_key=self.api_key, base_url=self.api_url)
        self.logger = logging.getLogger(__name__)

    def set_generation_model(self, model_id: str):
        """Select the model used by ``generate_text``."""
        self.generation_model_id = model_id

    def set_embedding_model(self, model_id: str, embedding_size: int):
        """Select the embedding model and record its vector size."""
        self.embedding_model_id = model_id
        self.embedding_size = embedding_size

    def process_text(self, text: str):
        """Truncate *text* to the configured character limit and strip it.

        Returns ``""`` for ``None`` or empty input instead of raising.
        """
        if not text:
            return ""
        return text[:self.default_input_max_characters].strip()

    def generate_text(self, prompt: str, chat_history: list = None,
                      max_output_tokens: int = None, temperature: float = None):
        """Generate a chat completion for *prompt*.

        Args:
            prompt: User prompt; it is passed through ``construct_prompt``
                and therefore truncated by ``process_text``.
            chat_history: Optional list of prior message dicts. It is copied,
                so the caller's list is not modified.
            max_output_tokens: Completion token limit; defaults to the
                instance default when ``None``.
            temperature: Sampling temperature; defaults to the instance
                default when ``None``. An explicit ``0`` is honoured.

        Returns:
            The generated text, or ``None`` if the provider is not
            configured, the response is malformed, or the request fails.
        """
        if not self.client:
            self.logger.error("OpenAI client was not initialized")
            return None

        if not self.generation_model_id:
            self.logger.error("OpenAI generation model not set")
            return None

        # Compare against None explicitly: `value or default` would discard
        # a deliberate temperature of 0.0 because it is falsy.
        if max_output_tokens is None:
            max_output_tokens = self.default_generation_max_output_tokens
        if temperature is None:
            temperature = self.default_generation_temperature

        messages = list(chat_history) if chat_history else []
        messages.append(self.construct_prompt(prompt, OpenAIEnums.USER.value))

        try:
            response = self.client.chat.completions.create(
                model=self.generation_model_id,
                messages=messages,
                max_completion_tokens=max_output_tokens,
                temperature=temperature
            )

            if (not response or not response.choices
                    or not response.choices[0].message
                    or not response.choices[0].message.content):
                self.logger.error("Invalid OpenAI response format")
                return None

            return response.choices[0].message.content

        except Exception as e:
            self.logger.exception("Error while generating text with OpenAI: %s", e)
            return None

    def embed_text_batch(self, texts: list[str], batch_size: int = 32):
        """Embed several texts, sending them to the API in batches.

        Args:
            texts: Texts to embed.
            batch_size: Maximum number of texts per request; values below 1
                are treated as 1.

        Returns:
            A list of embedding vectors in the same order as *texts* (an
            empty list for empty input), or ``None`` if the provider is not
            configured, a response is malformed, or a request fails.
        """
        if not self.client:
            self.logger.error("OpenAI client was not initialized")
            return None

        if not self.embedding_model_id:
            self.logger.error("OpenAI embedding model not set")
            return None

        if not texts:
            return []

        step = max(1, int(batch_size or 1))
        all_embeddings = []

        try:
            for start in range(0, len(texts), step):
                batch = list(texts[start:start + step])
                response = self.client.embeddings.create(
                    model=self.embedding_model_id,
                    input=batch
                )

                if not response or not response.data or len(response.data) != len(batch):
                    self.logger.error("Invalid OpenAI embedding response")
                    return None

                # Each returned item carries the index of its input; sort on
                # it so the output order matches the input order.
                ordered = sorted(response.data, key=lambda item: item.index)
                all_embeddings.extend(item.embedding for item in ordered)

        except Exception as e:
            self.logger.exception("Error while embedding text batch with OpenAI: %s", e)
            return None

        return all_embeddings

    def embed_text(self, text: str, document_type: str = None):
        """Embed a single text.

        Args:
            text: Text to embed; sent to the API as-is.
            document_type: Accepted but not used by this provider.

        Returns:
            The embedding vector, or ``None`` if the provider is not
            configured, the response is malformed, or the request fails.
        """
        if not self.client:
            self.logger.error("OpenAI client was not initialized")
            return None

        if not self.embedding_model_id:
            self.logger.error("OpenAI embedding model not set")
            return None

        try:
            response = self.client.embeddings.create(
                model=self.embedding_model_id,
                input=text
            )

            if not response or not response.data or not response.data[0].embedding:
                self.logger.error("Invalid OpenAI embedding response")
                return None

            return response.data[0].embedding

        except Exception as e:
            self.logger.exception("Error while embedding text with OpenAI: %s", e)
            return None

    def construct_prompt(self, prompt: str, role: str):
        """Build a chat message dict for *role* with the processed prompt."""
        return {
            "role": role,
            "content": self.process_text(prompt)
        }
stores/llm/providers/OpenRouterProvider.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from stores.llm.LLMInterface import LLMInterface
2
+ import logging
3
+ import requests
4
+ import re
5
+ import os
6
+
7
+
8
class OpenRouterProvider(LLMInterface):
    """LLM provider that talks to the OpenRouter chat-completions HTTP API.

    Supports text generation and a model-backed web search. Embeddings are
    not supported: both embedding methods log a warning and return ``None``.
    """

    # Passed to requests as the timeout, which requests interprets in seconds.
    # NOTE(review): 6000 s is 100 minutes — value kept from the original call
    # sites; confirm it was not meant to be 60.
    REQUEST_TIMEOUT_SECONDS = 6000

    def __init__(self, url: str = None, model: str = None,
                 default_input_max_characters: int = 1000,
                 default_generation_max_output_tokens: int = 1000,
                 default_generation_temperature: float = 0.1, api_key: str = None):
        """Store connection settings and generation defaults.

        Args:
            url: API base URL; defaults to ``https://openrouter.ai/api/v1``.
            model: Generation model identifier.
            default_input_max_characters: Stored on the instance;
                ``process_text`` in this provider does not truncate.
            default_generation_max_output_tokens: Token limit used when a
                call does not supply one.
            default_generation_temperature: Temperature used when a call
                does not supply one.
            api_key: API key; falls back to the ``OPENROUTER_API_KEY``
                environment variable.
        """
        self.url = url or "https://openrouter.ai/api/v1"
        self.api_key = api_key or os.getenv("OPENROUTER_API_KEY")
        self.model = model
        # Kept in sync with `model` so code inspecting either attribute
        # sees the configured generation model.
        self.generation_model_id = model

        self.embedding_model = None
        self.embedding_model_id = None
        self.embedding_size = None

        self.default_input_max_characters = default_input_max_characters
        self.default_generation_max_output_tokens = default_generation_max_output_tokens
        self.default_generation_temperature = default_generation_temperature
        self.logger = logging.getLogger(__name__)

    def set_generation_model(self, model_id: str):
        """Select the generation model; a falsy *model_id* is ignored."""
        if model_id:
            self.model = model_id
            self.generation_model_id = model_id

    def set_embedding_model(self, model_id: str, embedding_size: int):
        """Record the embedding model and size; a falsy *model_id* is ignored."""
        if model_id:
            self.embedding_model = model_id
            self.embedding_size = embedding_size
            self.embedding_model_id = model_id

    def process_text(self, text: str):
        """Return *text* as a stripped string, or ``""`` for falsy input."""
        if not text:
            return ""
        return str(text).strip()

    def _post_chat_completion(self, payload: dict):
        """POST *payload* to the chat-completions endpoint and return the response."""
        endpoint = self.url.rstrip("/") + "/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            # Recommended by OpenRouter for usage tracking
            "HTTP-Referer": os.getenv("OPENROUTER_SITE_URL", "http://localhost"),
            "X-Title": os.getenv("OPENROUTER_APP_NAME", "LLMApp"),
        }
        return requests.post(endpoint, json=payload, headers=headers,
                             timeout=self.REQUEST_TIMEOUT_SECONDS)

    def generate_text(self, prompt: str, chat_history: list = None,
                      max_output_tokens: int = None, temperature: float = None):
        """Generate a chat completion for *prompt*.

        Args:
            prompt: User prompt, passed through ``process_text``.
            chat_history: Optional list of message dicts; each entry's
                ``role`` (default ``"user"``) and ``content`` (default
                ``""``) are forwarded.
            max_output_tokens: Token limit; defaults to the instance
                default when ``None``.
            temperature: Sampling temperature; defaults to the instance
                default when ``None``. An explicit ``0`` is honoured.

        Returns:
            A dict with keys ``model``, ``response``, ``tokens_generated``,
            ``total_duration_ms`` (always ``None``) and
            ``prompt_eval_tokens``; or ``None`` when the request fails, the
            response is malformed, or the generated text is empty.
        """
        try:
            # Compare against None explicitly: `value or default` would
            # discard a deliberate temperature of 0.0 because it is falsy.
            if max_output_tokens is None:
                max_output_tokens = self.default_generation_max_output_tokens
            if temperature is None:
                temperature = self.default_generation_temperature

            messages = [
                {"role": entry.get("role", "user"), "content": entry.get("content", "")}
                for entry in (chat_history or [])
            ]
            messages.append({"role": "user", "content": self.process_text(prompt)})

            payload = {
                "model": self.model,
                "messages": messages,
                "max_tokens": int(max_output_tokens),
                "temperature": float(temperature),
            }

            resp = self._post_chat_completion(payload)
            if resp.status_code != 200:
                self.logger.error("OpenRouter generate failed: %s %s", resp.status_code, resp.text)
                return None

            data = resp.json()

            try:
                generated_text = data["choices"][0]["message"]["content"].strip()
            except (KeyError, IndexError, TypeError, AttributeError):
                # AttributeError covers a null `content`, which has no .strip().
                self.logger.error("Unexpected OpenRouter response structure: %s", data)
                return None

            if not generated_text:
                return None

            usage = data.get("usage", {})
            return {
                "model": data.get("model"),
                "response": generated_text,
                "tokens_generated": usage.get("completion_tokens"),
                "total_duration_ms": None,
                "prompt_eval_tokens": usage.get("prompt_tokens"),
            }

        except Exception as e:
            self.logger.exception("Error in OpenRouterProvider.generate_text: %s", e)
            return None

    def embed_text(self, text: str, document_type: str = None):
        """OpenRouter does not support embeddings natively — returns None."""
        self.logger.warning("OpenRouterProvider does not support embeddings.")
        return None

    def construct_prompt(self, prompt: str, role: str):
        """Build a chat message dict for *role* with the processed prompt."""
        return {
            "role": role,
            "content": self.process_text(prompt)
        }

    def embed_text_batch(self, texts: list[str], batch_size: int = 32):
        """OpenRouter does not support embeddings natively — returns None."""
        self.logger.warning("OpenRouterProvider does not support embeddings.")
        return None

    def clean_content(self, text: str) -> str:
        """Strip markdown links and bracketed markers from *text*.

        Removes ``[label](target)`` links, then remaining ``[...]`` spans,
        collapses runs of newlines and trims the result. Returns ``""`` for
        ``None`` or empty input.
        """
        if not text:
            return ""
        text = re.sub(r'\[.*?\]\(.*?\)', '', text)
        text = re.sub(r'\[[^\]]*\]', '', text)
        return re.sub(r'\n+', '\n', text).strip()

    def web_search(self, query: str):
        """Answer *query* through a search-capable OpenRouter model.

        The model is read from the ``OPENROUTER_SEARCH_MODEL`` environment
        variable and defaults to ``perplexity/sonar-online``.

        Args:
            query: Search query, sent as a single user message.

        Returns:
            A dict with ``"text"`` (the cleaned model answer, a "no results"
            notice on a non-200 response, or an error message) and
            ``"references"`` (de-duplicated URLs in first-seen order).
            Failures are reported through ``"text"``; nothing is raised.
        """
        try:
            online_model = os.getenv("OPENROUTER_SEARCH_MODEL", "perplexity/sonar-online")

            payload = {
                "model": online_model,
                "messages": [{"role": "user", "content": query}],
                "max_tokens": int(self.default_generation_max_output_tokens),
                "temperature": float(self.default_generation_temperature),
            }

            resp = self._post_chat_completion(payload)
            if resp.status_code != 200:
                self.logger.error("OpenRouter web search failed: %s %s", resp.status_code, resp.text)
                return {
                    "text": "No relevant external results found.",
                    "references": []
                }

            data = resp.json()

            try:
                message = data["choices"][0]["message"]
                raw_text = message["content"] or ""
            except (KeyError, IndexError, TypeError):
                message, raw_text = None, ""

            # A dict is used as an ordered set so the reference list is
            # deterministic and free of duplicates.
            references = {}

            # Extract URLs from the RAW answer: clean_content removes
            # markdown links together with their targets, so searching the
            # cleaned text would lose every linked URL.
            for found_url in re.findall(r"https?://[^\s)]+", raw_text):
                references.setdefault(found_url, None)

            # Also collect `url_citation` annotations when the message
            # carries them; entries of any other shape are skipped.
            annotations = message.get("annotations") if isinstance(message, dict) else None
            for annotation in annotations or []:
                citation = annotation.get("url_citation") if isinstance(annotation, dict) else None
                if isinstance(citation, dict) and citation.get("url"):
                    references.setdefault(citation["url"], None)

            return {
                "text": self.clean_content(raw_text),
                "references": list(references)
            }

        except Exception as e:
            self.logger.error("OpenRouter web search failed: %s", e)
            return {
                "text": f"OpenRouter search error: {str(e)}",
                "references": []
            }
stores/llm/providers/__init__.py ADDED
File without changes