Headers
+{json.dumps(log["headers"], indent=2)}
+ Body
+{log["body"]}
+ diff --git a/.env b/.env new file mode 100644 index 0000000000000000000000000000000000000000..0eac45f5272ab8266bcfc40060e78b0b34f4a694 --- /dev/null +++ b/.env @@ -0,0 +1,115 @@ +APP_NAME="IntegraRAG" +DEBUG=False +CustomLoaders=False + +# ---------- QDRANT ---------- Choose One +# QDRANT_TYPE="local" +# QDRANT_DOCKER_URL="" +# QDRANT_API_KEY="" + +# QDRANT_TYPE="docker" +# QDRANT_DOCKER_URL="http://localhost:6333/" +# QDRANT_API_KEY="" + +QDRANT_TYPE="cloud" +QDRANT_DOCKER_URL="https://d7e287d8-903d-436c-854c-03cbef9e4edb.us-east4-0.gcp.cloud.qdrant.io" +QDRANT_API_KEY="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIn0.NRbT0QPl7isuBKvdtganh89xa2DeMgKXZ3gSJngexQg" + + +# ---------- REDIS ---------- +REDIS_HOST="rediss://default:gQAAAAAAAS-BAAIncDFiM2E3OGQ1MmU5Zjk0OGM5ODU2ZmMzYzc4NjZjYzdjMHAxNzc2OTc@steady-clam-77697.upstash.io" +REDIS_PORT=6379 + +# ---------- WEBHOOKS ---------- +CALLBACK_URL="https://webhooksite.net/c93aac48-5237-4078-9511-14d778acba2f" +GRADE_WEBHOOK_URL="https://webhooksite.net/c93aac48-5237-4078-9511-14d778acba2f" + +# ---------- BACKENDS ---------- Choose One +#generation +# OLLAMA | COHERE | MISTRAL | GEMINI | HUGGINGFACE | GROQ | OPENROUTER | DEEPSEEK | +#embedding +# OLLAMA | COHERE | MISTRAL | GEMINI | HUGGINGFACE + +# ---------- OLLAMA ---------- +OLLAMA_URL="http://localhost:11434" +# OLLAMA_API_KEY="getAone" +# GENERATION_BACKEND="OLLAMA" +# EMBEDDING_BACKEND="OLLAMA" +# GENERATION_MODEL_ID="deepseek-v3.1:671b-cloud" +# EMBEDDING_MODEL_ID="embeddinggemma:latest" +# EMBEDDING_MODEL_SIZE=768 +# QDRANT_COLLECTION="768_docs" + + +# ---------- COHERE ---------- +COHERE_API_KEY="getAone" +# GENERATION_BACKEND="COHERE" +# EMBEDDING_BACKEND="COHERE" +# GENERATION_MODEL_ID="command-a-03-2025" +# EMBEDDING_MODEL_ID="embed-multilingual-v3.0" +# EMBEDDING_MODEL_SIZE=1024 +# QDRANT_COLLECTION="1024_docs" + + +# ---------- MISTRAL ---------- +MISTRAL_API_KEY="getAone" +# GENERATION_BACKEND="MISTRAL" +# EMBEDDING_BACKEND="MISTRAL" +# GENERATION_MODEL_ID="mistral-small-2603" +# EMBEDDING_MODEL_ID="mistral-embed-2312" +# EMBEDDING_MODEL_SIZE=1024 +# QDRANT_COLLECTION="1024_docs" + +# ---------- GEMINI ---------- +GEMINI_API_KEY="getAone" +GENERATION_BACKEND="GEMINI" +EMBEDDING_BACKEND="GEMINI" +GENERATION_MODEL_ID="gemini-2.5-flash" +EMBEDDING_MODEL_ID="gemini-embedding-001" +EMBEDDING_MODEL_SIZE=768 +QDRANT_COLLECTION="768_docs" + +# ---------- HUGGING FACE ---------- +HF_API_KEY="getAone" +# GENERATION_BACKEND="HUGGINGFACE" +# EMBEDDING_BACKEND="HUGGINGFACE" +# GENERATION_MODEL_ID="Qwen/Qwen2.5-72B-Instruct" +# EMBEDDING_MODEL_ID="google/embeddinggemma-300m" +# EMBEDDING_MODEL_SIZE=768 +# QDRANT_COLLECTION="768_docs" + +# ---------- DEEPSEEK ---------- paid +DEEPSEEK_API_KEY="getAone" +# GENERATION_BACKEND="DEEPSEEK" +# EMBEDDING_BACKEND="COHERE" +# GENERATION_MODEL_ID="deepseek-chat" +# EMBEDDING_MODEL_ID="embed-multilingual-v3.0" +# EMBEDDING_MODEL_SIZE=1024 +# QDRANT_COLLECTION="1024_docs" + + +# ---------- OPENAI ---------- paid +OPENAI_API_KEY="" +OPENAI_API_URL="" + + +# ---------- GROQ ----------not complete +GROQ_API_KEY="" + +# ---------- OPENROUTER ----------not complete +OPENROUTER_API_KEY="" +OPENROUTER_SITE_URL="http://localhost" +OPENROUTER_APP_NAME="IntegraRAG" +OPENROUTER_SEARCH_MODEL="perplexity/sonar-online" + + + +# ---------- DEFAULTS ---------- +INPUT_DAFAULT_MAX_CHARACTERS=2048 +GENERATION_DAFAULT_MAX_TOKENS=1200 +GENERATION_DAFAULT_TEMPERATURE=0.3 + +# ---------- CHUNKING ---------- +CHUNK_SIZE=700 +CHUNK_OVERLAP=150 +CHUNK_METHOD="recursive" \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..580f40192bdfc1a317161044638a16ddc805ab88 --- /dev/null +++ b/.gitignore @@ -0,0 +1,17 @@ +# Python virtual environment +venv/ +.env/ +__pycache__/ +*.pyc +*.pyo +*.pyd + +Code_Backups.txt +data/ +# VSCode +.vscode/ +.vs/ + +# OS +.DS_Store +Thumbs.db diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..ebb43d113dbb0a1b71ba9c120bbcf93a6d1be867 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,36 @@ +# ───────────────────────────────────────────── +# IntegraRAG — Production Dockerfile +# Services bundled: FastAPI + Celery worker +# External deps: Redis, Qdrant (cloud/managed) +# ───────────────────────────────────────────── +FROM python:3.11-slim + +# System deps +RUN apt-get update && apt-get install -y --no-install-recommends \ + build-essential \ + curl \ + libmagic1 \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /app + +# Install Python deps first (layer cache) +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy application source +COPY . . + +# ── Runtime env defaults (override via HF Secrets or docker run -e) ── +ENV PORT=7860 \ + PYTHONUNBUFFERED=1 \ + PYTHONDONTWRITEBYTECODE=1 + +# Hugging Face Spaces exposes port 7860 +EXPOSE 7860 + +# Entrypoint: start Celery worker in background, then FastAPI +COPY docker-entrypoint.sh /docker-entrypoint.sh +RUN chmod +x /docker-entrypoint.sh + +ENTRYPOINT ["/docker-entrypoint.sh"] diff --git a/README.md b/README.md index e6be6398393b4a7e51951dba5ac1dcb512fcea05..6831cbdcb9868c51560a976e42800f837d4ac850 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,27 @@ --- -title: EXAM RAG API -emoji: 💻 -colorFrom: gray +title: IntegraRAG API +emoji: 🧠 +colorFrom: indigo colorTo: blue sdk: docker +app_port: 7860 pinned: false --- -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# IntegraRAG — RAG-Powered Exam & Assistant API + +FastAPI backend with Celery workers for document-based Q&A, exam generation, and AI-graded exam submissions. + + +`conda create -n RAG_API python==3.11` +`conda activate RAG_API` +`pip install -r requirements.txt` + +`docker run -p 6333:6333 qdrant/qdrant` +`docker run -d -p 6379:6379 redis:7` + +# View The .env + +`celery -A celery_app.celery_app worker -P threads --loglevel=info` +`uvicorn main:app --host 0.0.0.0 --port 8030 --reload` +`uvicorn webhook:app --reload` \ No newline at end of file diff --git a/celery_app.py b/celery_app.py new file mode 100644 index 0000000000000000000000000000000000000000..f58ae3e6179418b29b643590e510c5137259a4f9 --- /dev/null +++ b/celery_app.py @@ -0,0 +1,30 @@ +# celery_app.py +from celery import Celery +import redis +from config import get_settings + +celery_app = Celery( + "assistant_worker", + broker=f"{get_settings().REDIS_HOST}:{get_settings().REDIS_PORT}/0", + backend=f"{get_settings().REDIS_HOST}:{get_settings().REDIS_PORT}/1", + include=['generation.ExamAnswer'] +) + +celery_app.conf.update( + task_serializer="json", + accept_content=["json"], + result_serializer="json", + task_track_started=True, + task_time_limit=60*60, +) + +import worker.tasks +from generation.ExamAnswer import grade_exam_task +def clear_redis_backend(): + r = redis.Redis(host=get_settings().REDIS_HOST, port=get_settings().REDIS_PORT, db=1) + r.flushdb() + print("Redis result backend cleared!") + +@celery_app.on_after_configure.connect +def setup(sender, **kwargs): + clear_redis_backend() \ No newline at end of file diff --git a/config.py b/config.py new file mode 100644 index 0000000000000000000000000000000000000000..b81c277e8d5a7a16e0c8a5e76dc6ff44151b24c5 --- /dev/null +++ b/config.py @@ -0,0 +1,64 @@ +from pydantic_settings import BaseSettings +from functools import lru_cache + + +class Settings(BaseSettings): + DEBUG: bool = False + APP_NAME: str + QDRANT_COLLECTION: str = "docs" + + CustomLoaders: bool = None + QDRANT_TYPE: str = "docker" + QDRANT_DOCKER_URL: str = "http://localhost:6333" + QDRANT_API_KEY: str = None + CHUNK_SIZE: int = 1000 + CHUNK_OVERLAP: int = None + CHUNK_METHOD: str = None + GRADE_WEBHOOK_URL: str = None + REDIS_HOST: str = "localhost" + REDIS_PORT: int = 6379 + CALLBACK_URL: str = None + + # ---------- BACKENDS ---------- + GENERATION_BACKEND: str = "OLLAMA" + EMBEDDING_BACKEND: str = "OLLAMA" + + # ---------- API KEYS ---------- + OPENAI_API_KEY: str = None + OPENAI_API_URL: str = None + + COHERE_API_KEY: str = None + + OLLAMA_URL: str = "http://localhost:11434" + OLLAMA_API_KEY: str = None + + MISTRAL_API_KEY: str = None + + GROQ_API_KEY: str = None + + OPENROUTER_API_KEY: str = None + OPENROUTER_SITE_URL: str = "http://localhost" # forwarded as HTTP-Referer + OPENROUTER_APP_NAME: str = "IntegraRAG" # forwarded as X-Title + OPENROUTER_SEARCH_MODEL: str = "perplexity/sonar-online" + + HF_API_KEY: str = None + + DEEPSEEK_API_KEY: str = None + + GEMINI_API_KEY: str = None + + # ---------- MODELS ---------- + GENERATION_MODEL_ID: str = "deepseek-v3.1:671b-cloud" + EMBEDDING_MODEL_ID: str = "embeddinggemma:latest" + EMBEDDING_MODEL_SIZE: int = 768 + INPUT_DAFAULT_MAX_CHARACTERS: int = None + GENERATION_DAFAULT_MAX_TOKENS: int = None + GENERATION_DAFAULT_TEMPERATURE: float = None + + class Config: + env_file = ".env" + + +@lru_cache +def get_settings(): + return Settings() diff --git a/docker-entrypoint.sh b/docker-entrypoint.sh new file mode 100644 index 0000000000000000000000000000000000000000..bea568c2accb7cb80d3673db09416f4dfddd72c0 --- /dev/null +++ b/docker-entrypoint.sh @@ -0,0 +1,14 @@ +#!/bin/bash +set -e + +echo "==> Starting Celery worker in background..." +celery -A celery_app.celery_app worker \ + -P threads \ + --loglevel=info \ + --concurrency=4 & + +echo "==> Starting FastAPI (uvicorn) on port ${PORT:-7860}..." +exec uvicorn main:app \ + --host 0.0.0.0 \ + --port "${PORT:-7860}" \ + --workers 1 diff --git a/generation/AssistantRagGenerator.py b/generation/AssistantRagGenerator.py new file mode 100644 index 0000000000000000000000000000000000000000..2b71918236361b6e9e7a1b94e6e229ba1caa41d1 --- /dev/null +++ b/generation/AssistantRagGenerator.py @@ -0,0 +1,201 @@ +from typing import Any +from pydantic import Field +from langchain_core.language_models import LLM +from langchain_core.runnables import RunnableBranch, RunnableLambda, RunnableParallel +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import PromptTemplate +from stores.llm.LLMProviderFactory import LLMProviderFactory +from config import get_settings + +class ProviderLLMWrapper(LLM): + provider: Any = Field(..., description="The wrapped LLM provider") + def _call(self, prompt: str, stop=None) -> str: + # Calls the underlying model and ensures a string is returned + result = self.provider.generate_text(prompt) + if result is None: + raise ValueError("LLM provider returned None (likely due to timeout or error)") + if isinstance(result, dict): + response = result.get("response") + if response is None: + raise ValueError(f"LLM provider returned dict without 'response' key: {result.keys()}") + return response + if isinstance(result, str): + return result + raise ValueError(f"Unexpected LLM response type: {type(result).__name__}") + @property + def _llm_type(self): + return "custom-provider" + + def get_num_tokens(self, text: str) -> int: + return len(text.split()) + +class AssistantRagGen: + def __init__(self): + config = get_settings() + self.factory = LLMProviderFactory(config) + self.generator = self.factory.create(config.GENERATION_BACKEND) + self.generator.set_generation_model(config.GENERATION_MODEL_ID) + self.llm = ProviderLLMWrapper(provider=self.generator) + self.valid_routes = {"user_info", "site_query", "pdf_query"} + + def build_router_prompt(self, user_prompt: str) -> str: + return f"""You are a query routing classifier. Your sole job is to categorize a user's question into exactly one routing category. + + ## Categories + + | Category | Routes questions about... | + |--------------|-------------------------------------------------------------------------------------------| + | `user_info` | Personal profile, enrolled courses, username, role, learning progress, achievements | + | `site_query` | Platform features, website navigation, rules, policies, FAQs, general platform knowledge | + | `pdf_query` | Document content, uploaded files, PDF search, lesson materials, reading resources | + + ## Examples + + user_info → "What courses am I enrolled in?" + user_info → "What is my current progress in the Python course?" + site_query → "How do I reset my password?" + site_query → "What are the platform's refund policies?" + pdf_query → "What does the document say about recursion?" + pdf_query → "Find me the section on neural networks in the materials" + + ## Decision Rules + + 1. If the question involves the **current user's personal data** → `user_info` + 2. If the question is about **how the platform works** → `site_query` + 3. If the question requires **reading or searching a document** → `pdf_query` + 4. When ambiguous, prefer `pdf_query` over `site_query`, and `user_info` over both. + + ## Output Format + + Respond with a single lowercase word. No punctuation. No explanation. No whitespace. + + Valid outputs: user_info | site_query | pdf_query + + Question: {user_prompt} + """ + + def build_unified_prompt(self, context: str, question: str, conversation_history: str = "", User_Info: str = "") -> str: + return f""" + You are a helpful university assistant. + + Rules: + - Use the provided context FIRST. + - Use conversation history to understand follow-up questions. + - If the question is about the user, use the User_Info and enrolled_courses. + - If the answer is not in the context, say: + "Not found in the provided materials." + Then add: + "From my own information:" and answer briefly. + - Be concise and clear. + + Conversation History: + {conversation_history if conversation_history else "None"} + + User Info: + {User_Info if User_Info else "None"} + + Context: + {context} + + Current Question: + {question} + + Answer: + """ + + def build_user_info_prompt(self, question: str, conversation_history: str = "", User_Info: str = "") -> str: + return f""" + You are a university assistant handling a user account inquiry. + Use the provided User Info and Enrolled Courses to answer the question accurately. + + Conversation History: + {conversation_history if conversation_history else "None"} + + User Info: + {User_Info if User_Info else "None"} + + Current Question: + {question} + + Answer: + """ + + def build_site_query_prompt(self, question: str,context:str="", conversation_history: str = "") -> str: + return f""" + You are a university assistant handling a platform or site-related question. + Provide clear instructions, rules, or general information about how the university platform works. + + Conversation History: + {conversation_history if conversation_history else "None"} + + Current Question: + {question} + + Site Context: + {context if context else "None"} + + Answer: + """ + + def robust_router(self, input_data: dict) -> str: + question = input_data["question"] + attempts = 0 + while attempts < 3: + prompt = self.build_router_prompt(question) + route = self.llm.invoke(prompt).strip().lower() + + if route in self.valid_routes: + return route + attempts += 1 + return "pdf_query" + + def get_chain(self): + router_node = RunnableLambda(self.robust_router) + + user_info_chain = RunnableLambda(lambda x: self.llm.invoke( + self.build_user_info_prompt( + question=x["question"], + conversation_history=x.get("conversation_history", ""), + User_Info=x.get("User_Info", ""), + ) + )) + + site_query_chain = RunnableLambda(lambda x: self.llm.invoke( + self.build_site_query_prompt( + question=x["question"], + context=x.get("context", ""), + conversation_history=x.get("conversation_history", "") + ) + )) + + pdf_query_chain = RunnableLambda(lambda x: self.llm.invoke( + self.build_unified_prompt( + context=x.get("context", "No context provided."), + question=x["question"], + conversation_history=x.get("conversation_history", ""), + User_Info=x.get("User_Info", ""), + ) + )) + + branching_logic = RunnableBranch( + (lambda x: x["topic"] == "user_info", user_info_chain), + (lambda x: x["topic"] == "site_query", site_query_chain), + pdf_query_chain + ) + + full_chain = ( + RunnableParallel({ + "topic": router_node, + # Pass all incoming variables straight through to the branches + "question": lambda x: x["question"], + "context": lambda x: x.get("context", ""), + "conversation_history": lambda x: x.get("conversation_history", ""), + "User_Info": lambda x: x.get("User_Info", ""), + "enrolled_courses": lambda x: x.get("enrolled_courses", "") + }) + | branching_logic + | StrOutputParser() + ) + + return full_chain + diff --git a/generation/ExamAnswer.py b/generation/ExamAnswer.py new file mode 100644 index 0000000000000000000000000000000000000000..af4abe6c86f78a70bed24e6b69af1960219e5670 --- /dev/null +++ b/generation/ExamAnswer.py @@ -0,0 +1,314 @@ +import logging +from datetime import datetime +from typing import List, Dict, Any +from celery import shared_task +import json +import re +import httpx + +from generation.answer_models import (ExamSubmission,ExamResult,StudentAnswer,GradedAnswer,QuestionType) +from indexing.indexingController import IndexingController +from stores.llm.LLMProviderFactory import LLMProviderFactory +from config import get_settings + + +def calculate_grade(percentage: float) -> str: + if percentage >= 90: + return "A" + elif percentage >= 80: + return "B" + elif percentage >= 70: + return "C" + elif percentage >= 60: + return "D" + else: + return "F" + + +logger = logging.getLogger(__name__) + +class ExamGradingService: + def __init__(self, use_ai_for_essays: bool = True): + self.use_ai_for_essays = use_ai_for_essays + + config = get_settings() + + factory = LLMProviderFactory(config) + provider = factory.create(config.GENERATION_BACKEND) + provider.set_generation_model(config.GENERATION_MODEL_ID) + self.llm = provider + + self.semantic_threshold = 0.65 + self.high_confidence = 0.85 + + def grade_submission(self, submission: ExamSubmission) -> ExamResult: + graded_answers: List[GradedAnswer] = [] + total_score = 0 + max_total_score = 0 + + for ans in submission.answers: + correct_answer = None + if ans.metadata: + correct_answer = ans.metadata.get("correct_answer") + + graded = self.grade_answer(ans, correct_answer,submission.course_id) + graded_answers.append(graded) + total_score += graded.score + max_total_score += graded.max_score + + percentage = (total_score / max_total_score) * 100 if max_total_score else 0 + grade = calculate_grade(percentage) + + return ExamResult( + exam_id=submission.exam_id, + student_id=submission.student_id, + student_name=submission.student_name, + graded_answers=graded_answers, + total_score=total_score, + max_total_score=max_total_score, + percentage=percentage, + grade=grade, + feedback_summary="RAG based grading using LLM evaluation", + submission_time=submission.submission_time, + graded_time=datetime.utcnow().isoformat() + ) + + def grade_answer(self, answer: StudentAnswer, correct_answer: Any, course) -> GradedAnswer: + if answer.question_type in [QuestionType.MULTIPLE_CHOICE,QuestionType.TRUE_FALSE]: + student_str = str(answer.student_response).strip().lower() + if answer.question_type == QuestionType.TRUE_FALSE: + if isinstance(correct_answer, bool): + correct_bool = correct_answer + elif isinstance(correct_answer, str): + correct_bool = correct_answer.lower() in ['true', 't', '1', 'yes', 'True'] + else: + correct_bool = bool(correct_answer) + student_bool = student_str in ['true', 't', '1', 'yes'] + is_correct = student_bool == correct_bool + score = answer.max_score if is_correct else 0 + feedback = "Exact match grading" + else: # multiple_choice + correct_str = str(correct_answer).strip().lower() if correct_answer else "" + is_correct = student_str == correct_str + score = answer.max_score if is_correct else 0 + feedback = "Exact match grading" + else: + if self.use_ai_for_essays and correct_answer: + score, feedback = self.ai_semantic_grade( + answer.question_text, + answer.student_response, + correct_answer, + answer.max_score, + course=course + ) + is_correct = score >= (answer.max_score * self.semantic_threshold) + else: + similarity = self.simple_similarity( + answer.student_response, + correct_answer + ) + score = similarity * answer.max_score + is_correct = similarity >= self.semantic_threshold + feedback = f"Similarity score {similarity:.2f}" + + return GradedAnswer( + question_no=answer.question_no, + question_type=answer.question_type, + question_text=answer.question_text, + student_response=answer.student_response, + correct_answer=correct_answer, + score=score, + max_score=answer.max_score, + feedback=feedback, + is_correct=is_correct + ) + + def simple_similarity(self, student: str, correct: str) -> float: + if not student or not correct: + return 0 + student_words = set(student.lower().split()) + correct_words = set(correct.lower().split()) + intersection = student_words.intersection(correct_words) + union = student_words.union(correct_words) + return len(intersection) / len(union) + + def retrieve_context(self, question: str, course:str): + """ + Retrieve relevant context from Qdrant for a given question filtered by course + Args: question: The question text to embed and search for // course: Optional course filter + Returns: String containing concatenated context from top 3 chunks + """ + try: + controller = IndexingController() + embedding = controller.embedder.embed_text(question) + + # Build metadata filters course + filters = [] + if course: + filters.append({ + "field": "course", + "op": "eq", + "value": course, + "clause": "must" + }) + + # Query Qdrant with filters + results = controller.vector_store.query_qdrant(embedding=embedding,filters=filters,top_k=5) + + context = "\n".join(r["content"] for r in results if r.get("content")) + + logger.info(f"Retrieved {len(results)} chunks for question (filtered by course={course})") + return context + + except Exception as e: + logger.error(f"Context retrieval failed: {e}") + return "" + + def build_prompt(self, question, student_answer, correct_answer, context): + return f""" +You are an academic exam grader. + +Question: +{question} + +Correct Answer: +{correct_answer} + +Reference Material: +{context} + +Student Answer: +{student_answer} + +Evaluate the student answer using semantic similarity. +You may slightly use your knowledge if correct answer not in Reference Material. + +Return JSON only: + +{{ +"score": number between 0 and 1, +"feedback": short explanation +}} +""" + + def parse_llm_output(self, text: str): + try: + if isinstance(text, dict): + if 'response' in text: + text = text['response'] + else: + text = str(text) + elif hasattr(text, 'content'): + text = text.content + elif hasattr(text, 'text'): + text = text.text + text = str(text).strip() + if not text: + return 0, "Empty response from LLM" + text = re.sub(r'```json\s*|\s*```', '', text) + try: + data = json.loads(text) + except json.JSONDecodeError: + json_match = re.search(r'\{.*\}', text, re.DOTALL) + if json_match: + data = json.loads(json_match.group()) + else: + raise + + score = float(data.get("score", 0)) + feedback = data.get("feedback", "") + score = max(0, min(score, 1)) + return score, feedback + + except Exception as e: + logger.error(f"Failed to parse LLM output: {e}, text type: {type(text)}") + return 0, "Failed to parse AI grading" + + def ai_semantic_grade(self, question, student, correct, max_score, course): + """ + Grade an answer using AI with context from Qdrant. + Args: question: The question text // student: Student's answer // correct: Correct answer + max_score: Maximum score for this question // course: Optional course for filtering context + Returns: // Tuple of (score, feedback) + """ + try: + # Retrieve context filtered by username and course + context = self.retrieve_context(question, course) + + prompt = self.build_prompt(question,student,correct,context) + + response = self.llm.generate_text(prompt) + + # Log response type for debugging + logger.info(f"Response type: {type(response)}") + + score_ratio, feedback = self.parse_llm_output(response) + score = score_ratio * max_score + + return score, feedback + + except Exception as e: + logger.error(f"AI grading failed: {e}") + # Fallback to simple similarity + similarity = self.simple_similarity(student, correct) + return similarity * max_score, f"Fallback similarity grading: {similarity:.2f}" + +@shared_task +def grade_exam_task(submission_dict: Dict[str, Any]): + submission = None + try: + submission = ExamSubmission(**submission_dict) + service = ExamGradingService() + result = service.grade_submission(submission) + result_dict = result.model_dump() + + # Send webhook with grade only + try: + webhook_url = get_settings().GRADE_WEBHOOK_URL + print(f" Webhook URL: {webhook_url}") + + if webhook_url: + # Create grade-only payload + grade_only_payload = { + "status": "completed", + "exam_id": submission.exam_id, + "student_id": submission.student_id, + "course_id":submission.course_id, + "grade": { + "total_score": result_dict['total_score'], + "max_total_score": result_dict['max_total_score'], + "percentage": result_dict['percentage'], + "grade": result_dict['grade'], + "graded_time": result_dict['graded_time'] + }, + "result" : result_dict, + } + + response = httpx.post( + webhook_url, + json=grade_only_payload, + timeout=30.0 + ) + print(f" Response status: {response.status_code}") + + if response.status_code == 200: + print(" Grade-only webhook sent successfully!") + else: + print(f" Webhook returned status: {response.status_code}") + print(f" Response: {response.text[:200]}") + else: + print("WEBHOOK_URL is empty or not set!") + + except Exception as e: + print(f" Webhook error: {type(e).__name__}: {e}") + import traceback + traceback.print_exc() + + print(" Task completed successfully") + return result_dict + + except Exception as e: + print(f" ERROR in task: {type(e).__name__}: {e}") + import traceback + traceback.print_exc() + raise \ No newline at end of file diff --git a/generation/ExamRagGenerator.py b/generation/ExamRagGenerator.py new file mode 100644 index 0000000000000000000000000000000000000000..9efc4c5232128a682b3c50c6fecbf00d6245a1ea --- /dev/null +++ b/generation/ExamRagGenerator.py @@ -0,0 +1,460 @@ +from typing import List, Dict +import logging +import json +import re +import math +from json_repair import repair_json +from pydantic import parse_obj_as +from collections import defaultdict +from config import get_settings +from routes.schemas.Exam_Models import * +from stores.llm.LLMProviderFactory import LLMProviderFactory +from generation.AssistantRagGenerator import ProviderLLMWrapper +from generation.prompts import ExamPromptBuilder +from indexing.indexingController import IndexingController + +class ExamService: + MAX_CHUNK_CHARS = 2000 + MAX_TOTAL_CONTEXT = 8000 + MAX_SCORE = 40 + PASS_THRESHOLD = int(MAX_SCORE * 0.8) + MAX_GENERATION_ATTEMPTS = 3 + + def __init__(self): + self.logger = logging.getLogger(__name__) + self._models_initialized = False + self.settings=get_settings() + self._init_models() + self.prompts=ExamPromptBuilder() + self.controller = IndexingController() + self.store = self.controller.vector_store + self.BATCH_SIZE=10 + + def _init_models(self): + if self._models_initialized: + return + factory = LLMProviderFactory(self.settings) + self.generator = factory.create(self.settings.GENERATION_BACKEND) + self.generator.set_generation_model(self.settings.GENERATION_MODEL_ID) + self.embedding_provider = factory.create(self.settings.EMBEDDING_BACKEND) + self.embedding_provider.set_embedding_model( + self.settings.EMBEDDING_MODEL_ID, + self.settings.EMBEDDING_MODEL_SIZE + ) + self.llm = ProviderLLMWrapper(provider=self.generator) + self._models_initialized = True + + def _extract_json(self, text: str) -> dict: + """ + Extract the first valid JSON object from LLM output. Attempts to repair malformed JSON using `repair_json`. + """ + match = re.search(r"\{.*\}", text, re.DOTALL) + if not match: + self.logger.error("No JSON found in LLM response:\n%s", text) + raise ValueError("LLM returned no JSON") + json_str = match.group(0) + # Try to load directly + try: + return json.loads(json_str) + except json.JSONDecodeError: + self.logger.warning("Invalid JSON extracted, attempting repair...") + try: + repaired_str = repair_json(json_str) + return json.loads(repaired_str) + except Exception as e: + self.logger.error("Failed to repair JSON:\n%s\nError: %s", json_str, e) + raise + + def normalize_exam_dict(self, data: dict): + # Normalize difficulty enum + if "difficulty" in data: + diff = data["difficulty"] + if isinstance(diff, str): + if "." in diff: + diff = diff.split(".")[-1] + data["difficulty"] = diff.lower() + # Normalize questions + questions = data.get("questions") + if not isinstance(questions, list): + return data + normalized_questions = [] + for q in questions: + if not isinstance(q, dict): + continue + q.pop("id", None) + q.pop("question_id", None) + q.pop("points", None) + + # normalize type + q_type = q.get("type") + if isinstance(q_type, str): + q_type = q_type.lower().strip() + if q_type == "truefalse": + q_type = "true_false" + q["type"] = q_type + + # normalize question text + if "question" in q: + q["question"] = str(q["question"]).strip() + + # MCQ normalization + if q_type == "mcq": + options = q.get("options") + # dict -> list + if isinstance(options, dict): + options = list(options.values()) + # string -> split into options + elif isinstance(options, str): + parts = re.split(r"[A-D]\)|\n|\r", options) + options = [ + p.strip(" .-") + for p in parts + if p.strip() + ] + # ensure list[str] + if isinstance(options, list): + options = [str(o).strip() for o in options] + else: + options = [] + q["options"] = options + + # normalize correct answer + correct = q.get("correct_answer") + if correct is not None: + correct = str(correct).strip() + q["correct_answer"] = correct + # ensure correct answer exists in options + if correct not in q["options"]: + q["options"].append(correct) + # ensure explanation exists + q.setdefault("explanation", "") + + # True/False normalization + elif q_type == "true_false": + ans = q.get("correct_answer") + if isinstance(ans, str): + ans = ans.lower() + if ans in ["true", "t", "1", "yes"]: + ans = True + elif ans in ["false", "f", "0", "no"]: + ans = False + q["correct_answer"] = ans + q.setdefault("explanation", "") + + # Short Answer normalization + elif q_type == "short_answer": + if "answer" in q: + q["answer"] = str(q["answer"]).strip() + q.setdefault("explanation", "") + + # Essay normalization + elif q_type == "essay": + if "expected_keywords" in q: + keywords = q.pop("expected_keywords") + if isinstance(keywords, list): + q["answer_guidelines"] = ", ".join(keywords) + else: + q["answer_guidelines"] = str(keywords) + q.setdefault("answer_guidelines", "") + + # Code question normalization + elif q_type == "code": + if "solution" in q: + q["solution"] = str(q["solution"]) + q.setdefault("starter_code", None) + q.setdefault("explanation", "") + normalized_questions.append(q) + data["questions"] = normalized_questions + + return data + + def generate_exam(self, request: ExamGenerationRequest, context: str, llm, batch_size: int) -> List[QuestionUnion]: + """ + Generate a batch of questions from the LLM, ensuring valid QuestionUnion objects.Repairs incomplete MCQs automatically. + """ + # Prepare the prompt for the batch + batch_request = request.model_copy() + batch_request.total_questions = batch_size + + prompt = self.prompts.build_exam_generation_prompt(batch_request, context) + raw_text = llm._call(prompt) + + if not raw_text: + raise RuntimeError("LLM generation failed") + + cleaned = re.sub(r"```[a-zA-Z]*|```", "", raw_text).strip() + + try: + exam_dict = self._extract_json(cleaned) + exam_dict = self.normalize_exam_dict(exam_dict) + + questions = exam_dict.get("questions") or [] + questions = questions[:batch_size] + + # Repair incomplete MCQs or missing fields + repaired_questions = [] + for q in questions: + if not isinstance(q, dict): + continue # skip invalid entries + q_type = q.get("type") + if q_type == "mcq": + if not q.get("options"): + self.logger.warning(f"Skipping MCQ with no options: {q}") + continue + if not q.get("correct_answer"): + q["correct_answer"] = q["options"][0] # safe placeholder + repaired_questions.append(q) + + # Convert to Pydantic QuestionUnion objects + questions = parse_obj_as(List[QuestionUnion], repaired_questions) + + self.logger.info( + "Batch requested=%d | received=%d | kept=%d", + batch_size, + len(exam_dict.get("questions", [])), + len(questions), + ) + + except json.JSONDecodeError: + self.logger.error("Invalid JSON from LLM:\n%s", raw_text) + raise + + return questions + + def evaluate_exam(self, request: ExamGenerationRequest, exam: ExamResponse, llm): + prompt = self.prompts.build_exam_evaluation_prompt(request, exam) + raw_text = llm._call(prompt) + + if not raw_text: + raise RuntimeError("Evaluation generation failed") + + cleaned = re.sub(r"```[a-zA-Z]*|```", "", raw_text).strip() + + try: + evaluation_dict = self._extract_json(cleaned) + except json.JSONDecodeError: + self.logger.error("Invalid evaluation JSON:\n%s", raw_text) + raise + + return EvaluationResult.model_validate(evaluation_dict) + + def split_chunks_by_topic_batches(self, exam_chunks, num_batches): + + self.logger.info(f"Topics retrieved: {list(exam_chunks.keys())}") + self.logger.info(f"Number of batches: {num_batches}") + + batches = [[] for _ in range(num_batches)] + + for topic, chunks in exam_chunks.items(): + total_chunks = len(chunks) + self.logger.info(f"Topic '{topic}' -> {total_chunks} chunks distributed across batches") + + for idx, chunk in enumerate(chunks): + batch_index = idx % num_batches + batches[batch_index].append(chunk) + + # Log batch composition + for i, batch in enumerate(batches): + topic_counter = defaultdict(int) + for chunk in batch: + topic = chunk.get("metadata", {}).get("topic", "unknown") + topic_counter[topic] += 1 + self.logger.info(f"Batch {i+1} contains {len(batch)} chunks -> {dict(topic_counter)}") + + return batches + + + def exam_task(self, request_dict: dict) -> ExamResponse: + """ + Generate a full exam using batching, safety break, and validated QuestionUnion questions.Each batch receives a portion of the retrieved chunks. + """ + request = ExamGenerationRequest.model_validate(request_dict) + # Prepare context from knowledge store + topics_with_embeddings = self.prepare_topics_with_embeddings(request.topics) + exam_chunks = self.store.retrieve_for_exam(topics_with_embeddings,request.username,request.course,request.references) + + # Determine number of batches + num_batches = math.ceil(request.total_questions / self.BATCH_SIZE) + self.logger.info(f"Raw exam_chunks structure: {type(exam_chunks)}") + + for k, v in exam_chunks.items(): + self.logger.info(f"Topic={k} | type={type(v)} | len={len(v) if hasattr(v,'__len__') else 'NA'}") + + chunk_batches = self.split_chunks_by_topic_batches(exam_chunks,num_batches) + + feedback_context = "" + + best_exam = None + best_score = 0 + + for attempt in range(self.MAX_GENERATION_ATTEMPTS): + self.logger.info(f"Generating exam attempt {attempt+1}") + + remaining_distribution: Dict[QuestionType, int] = dict(request.question_types_distribution) + all_questions: List[QuestionUnion] = [] + batch_index = 0 + + # Batch generation loop + while len(all_questions) < request.total_questions: + remaining = request.total_questions - len(all_questions) + batch_size = min(self.BATCH_SIZE, remaining) + # Determine batch distribution + batch_distribution: Dict[QuestionType, int] = {} + slots_left = batch_size + + for qtype, count in remaining_distribution.items(): + if slots_left <= 0: + break + + take = min(count, slots_left) + + if take > 0: + batch_distribution[qtype] = take + slots_left -= take + + if not batch_distribution: + break + + batch_request = request.model_copy() + batch_request.total_questions = sum(batch_distribution.values()) + batch_request.question_types_distribution = batch_distribution + + # Select chunk subset for this batch + + chunk_subset = chunk_batches[batch_index % len(chunk_batches)] + self.logger.info(f"\n===== BATCH {batch_index+1} CHUNKS =====") + + for i, chunk in enumerate(chunk_subset): + + meta = chunk.get("metadata", {}) + topic = meta.get("topic", "unknown") + page = meta.get("page", "NA") + + # Try common text keys + text = chunk.get("text") or chunk.get("content") or chunk.get("page_content") or "" + + preview = text[:200].replace("\n", " ") + + self.logger.info( + f"Chunk {i+1} | Topic={topic} | Page={page} | Preview={preview}" + ) + + self.logger.info("=====================================\n") + + batch_index += 1 + + batch_context = self.build_exam_context(chunk_subset) + + if feedback_context: + batch_context += f"\n\nEvaluator Feedback:\n{feedback_context}" + + # Generate questions + + batch_questions = self.generate_exam(batch_request,batch_context,self.llm,batch_request.total_questions) + + # Filter generated questions + for q in batch_questions: + if remaining_distribution.get(q.type, 0) > 0: + all_questions.append(q) + remaining_distribution[q.type] -= 1 + if len(all_questions) >= request.total_questions: + break + + # Build final exam + + exam_dict = { + "exam_id": request.exam_id, + "difficulty": request.difficulty, + "total_questions": request.total_questions, + "expected_distribution": request.question_types_distribution, + "questions": all_questions[:request.total_questions], + } + + try: + exam = ExamResponse.model_validate(exam_dict) + except Exception as e: + self.logger.error(f"Exam validation failed: {e}") + raise + + evaluation = self.evaluate_exam(request, exam, self.llm) + self.logger.info(f"Evaluation score: {evaluation.overall_score}") + + if evaluation.overall_score > best_score: + best_score = evaluation.overall_score + best_exam = exam + + if evaluation.overall_score >= self.PASS_THRESHOLD: + break + + feedback_context = evaluation.feedback + + if best_exam is None: + raise RuntimeError("Exam generation failed after retries") + + return best_exam + + + + def build_exam_context(self, exam_chunks) -> str: + """ + Accepts either: + 1) {topic: [chunks]} + 2) [chunks] + """ + + # Normalize structure + if isinstance(exam_chunks, list): + topic_chunks = defaultdict(list) + + for c in exam_chunks: + topic = c.get("metadata", {}).get("topic", "Unknown") + topic_chunks[topic].append(c) + + exam_chunks = topic_chunks + + context_parts = [] + total_length = 0 + + for topic, chunks in exam_chunks.items(): + + topic_header = f"\n### Topic: {topic}\n" + + if total_length + len(topic_header) > self.MAX_TOTAL_CONTEXT: + break + + context_parts.append(topic_header) + total_length += len(topic_header) + + for c in chunks: + + text = c.get("payload", {}).get("text", "") + source = c.get("metadata", {}).get("source", "") + bookmark = c.get("metadata", {}).get("bookmark_path", "") + + if not isinstance(text, str): + continue + + if len(text) > self.MAX_CHUNK_CHARS: + text = text[:self.MAX_CHUNK_CHARS] + + formatted_chunk = (f"[Source: {source} | Bookmark: {bookmark}]\n{text}\n") + + if total_length + len(formatted_chunk) > self.MAX_TOTAL_CONTEXT: + break + + context_parts.append(formatted_chunk) + total_length += len(formatted_chunk) + + return "\n".join(context_parts) + + + def prepare_topics_with_embeddings(self, topics: List[str]): + results = [] + for topic in topics: + try: + embedding = self.embedding_provider.embed_text(topic) + results.append((topic, embedding)) + except Exception as e: + self.logger.warning(f"Embedding failed for topic '{topic}': {e}") + self.logger.info(f"Prepared {len(results)} topic embeddings") + return results diff --git a/generation/__init__.py b/generation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/generation/answer_models.py b/generation/answer_models.py new file mode 100644 index 0000000000000000000000000000000000000000..c7eaaaa415a6b3df97717b9f531fc656b6dd6bd0 --- /dev/null +++ b/generation/answer_models.py @@ -0,0 +1,51 @@ +from pydantic import BaseModel +from typing import List, Optional, Dict, Any, Union +from enum import Enum + +class QuestionType(str, Enum): + MULTIPLE_CHOICE = "multiple_choice" + TRUE_FALSE = "true_false" + SHORT_ANSWER = "short_answer" + ESSAY = "essay" + CODE = "code" + +class StudentAnswer(BaseModel): + question_no: int + question_type: QuestionType + question_text: str + student_response: str + max_score: float = 1.0 + metadata: Optional[Dict[str, Any]] = {} + +class GradedAnswer(BaseModel): + question_no: int + question_type: QuestionType + question_text: str + student_response: str + correct_answer: Optional[Any] + score: float + max_score: float + feedback: str + is_correct: bool + +class ExamSubmission(BaseModel): + exam_id: str + course_id: str + student_id: str + student_name: Optional[str] + answers: List[StudentAnswer] + submission_time: str + metadata: Optional[Dict[str, Any]] = {} + +class ExamResult(BaseModel): + exam_id: str + student_id: str + student_name: Optional[str] + graded_answers: List[GradedAnswer] + total_score: float + max_total_score: float + percentage: float + grade: Optional[str] + feedback_summary: Optional[str] + submission_time: str + graded_time: str \ No newline at end of file diff --git a/generation/parsing_utils.py b/generation/parsing_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9913f87eb55f5cc7cde70c7580ee5845936104de --- /dev/null +++ b/generation/parsing_utils.py @@ -0,0 +1,51 @@ +import json +import logging +from typing import Any, Dict, Optional + +logger = logging.getLogger("ExamGraph") + +def safe_parse(parser_obj, text: str, question_no: int) -> Optional[Dict[str, Any]]: + if not text or text.strip() in ("null", "None", ""): + logger.warning(f"[Parse] q{question_no}: empty/null response") + return None + + last_error = None + + # Try direct parse + try: + result = parser_obj.parse(text) + return result.model_dump() if hasattr(result, "model_dump") else result + except Exception as e: + last_error = e + logger.debug(f"[Parse] q{question_no}: direct parse failed, trying extraction") + + # Try to extract JSON from text (LLM may have wrapped it in prose) + try: + # look for {...} pattern + start = text.rfind("{") + end = text.rfind("}") + 1 + if start >= 0 and end > start: + json_str = text[start:end] + json_obj = json.loads(json_str) + result = parser_obj.parse(json.dumps(json_obj)) + return result.model_dump() if hasattr(result, "model_dump") else result + except Exception as e: + last_error = e + logger.debug(f"[Parse] q{question_no}: json extraction failed") + + # Last resort: if it looks like partial JSON, mark for regen + error_msg = str(last_error) if last_error else "unknown" + logger.error(f"[Parse] q{question_no}: failed all attempts: {error_msg}") + return None + +def categorize_error(error_str: str) -> str: + err = error_str.lower() + if "timeout" in err: + return "timeout" + elif "json" in err or "invalid" in err: + return "invalid_json" + elif "field required" in err or "missing" in err: + return "missing_field" + elif "none" in err or "null" in err: + return "null" + return "unknown" diff --git a/generation/prompts.py b/generation/prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..ab4c09b887d521ae5b985318767495dc836085e3 --- /dev/null +++ b/generation/prompts.py @@ -0,0 +1,250 @@ +from generation.ExamRagGenerator import ExamGenerationRequest, ExamResponse +import json + +class ExamPromptBuilder: + MAX_SCORE = 40 + + def build_exam_generation_prompt(self,request: ExamGenerationRequest,context: str) -> str: + distribution = { + q_type.value: count + for q_type, count in request.question_types_distribution.items() + } + + return f""" + You are an automated exam generation system. + + Your job is to produce a structured exam strictly following the schema below. + + ---------------------------------------------------- + CRITICAL OUTPUT RULES + ---------------------------------------------------- + + You MUST return ONLY a valid JSON object. + + Do NOT include: + + - explanations + - markdown + - comments + - code blocks + - text before or after the JSON + + The response MUST start with {{ and end with }}. + + If the output is not valid JSON the result will be rejected. + + ---------------------------------------------------- + ENUM VALUES (STRICT) + ---------------------------------------------------- + + difficulty must be exactly one of: + + easy + medium + hard + + question type must be exactly one of: + + mcq + true_false + short_answer + essay + code + + ---------------------------------------------------- + EXAM REQUIREMENTS + ---------------------------------------------------- + + course: {request.course} + + difficulty: {request.difficulty.value} + + total_questions: {request.total_questions} + + question_types_distribution: + {json.dumps(distribution)} + + You MUST generate exactly: + + {json.dumps(distribution)} + + Example: + + {{ + "mcq": 3, + "essay": 2 + }} + + means exactly: + 3 mcq questions + 2 essay questions + + ---------------------------------------------------- + CONTEXT + ---------------------------------------------------- + + Use ONLY the information from this context when creating questions. + + {context} + + ---------------------------------------------------- + QUESTION RULES + ---------------------------------------------------- + + MCQ QUESTIONS + + - must contain exactly 4 options + - options must be plain text + - correct_answer must match one option EXACTLY + - do NOT use letters like A/B/C/D + - do NOT include numbering inside options + + Example: + + {{ + "type": "mcq", + "question": "What is 2 + 2?", + "options": ["1","2","3","4"], + "correct_answer": "4", + "explanation": "2 + 2 equals 4" + }} + + ---------------------------------------------------- + + TRUE/FALSE QUESTIONS + + correct_answer must be boolean. + + Example: + + {{ + "type": "true_false", + "question": "The Earth revolves around the Sun.", + "correct_answer": true, + "explanation": "Astronomy confirms this." + }} + + ---------------------------------------------------- + + SHORT ANSWER QUESTIONS + + Example: + + {{ + "type": "short_answer", + "question": "Define photosynthesis.", + "answer": "Process where plants convert light into chemical energy", + "explanation": "Occurs in chloroplasts using sunlight" + }} + + ---------------------------------------------------- + + ESSAY QUESTIONS + + Example: + + {{ + "type": "essay", + "question": "Explain Newton's First Law.", + "answer": "Newton's First Law states that an object will remain at rest or continue moving in a straight line at constant velocity unless acted upon by an external force. This property is called inertia. For example, a book on a table stays at rest until someone pushes it, and a moving car continues moving until friction or braking stops it.", + "answer_guidelines": "Describe inertia and provide examples" + }} + + ---------------------------------------------------- + + CODE QUESTIONS + + Rules: + + starter_code must be either a string OR null. + Never output the string "None". + + Example: + + {{ + "type": "code", + "question": "Write a Python function to compute factorial.", + "language": "c", + "starter_code": "def factorial(n):", + "solution": "def factorial(n): return 1 if n<=1 else n*factorial(n-1)", + "explanation": "Uses recursion" + }} + + ---------------------------------------------------- + IMPORTANT RESTRICTIONS + ---------------------------------------------------- + + Do NOT output: + + LaTeX + math formulas + markdown + additional fields + + Use plain text only. + + ---------------------------------------------------- + FINAL JSON STRUCTURE + ---------------------------------------------------- + + {{ + "exam_id": "{request.exam_id}", + "difficulty": "{request.difficulty.value}", + "total_questions": {request.total_questions}, + "expected_distribution": {json.dumps(distribution)}, + "questions": [] + }} + + Fill the questions array with the generated questions. + + ---------------------------------------------------- + + Return ONLY the JSON object. + """ + + def build_exam_evaluation_prompt(self,request: ExamGenerationRequest,exam: ExamResponse) -> str: + + exam_json = exam.model_dump_json() + + return f""" +You are an exam quality evaluator. + +-------------------------------- +OUTPUT RULES +-------------------------------- +1. Output MUST be valid JSON. +2. Do NOT include markdown. +3. Do NOT include reasoning outside JSON. +4. Output ONLY the JSON object. +5. JSON must start with {{ and end with }}. + +-------------------------------- +SCORING RANGE +-------------------------------- +0 to {self.MAX_SCORE} + +-------------------------------- +EVALUATION CRITERIA +-------------------------------- +1. Relevance of questions to the topics +2. Correct distribution of question types +3. Clarity and wording of questions +4. Difficulty consistency +5. Correctness of answers + +-------------------------------- +EXAM TO EVALUATE +-------------------------------- +{exam_json} + +-------------------------------- +OUTPUT FORMAT +-------------------------------- + +{{ +"overall_score": integer between 0 and {self.MAX_SCORE}, +"feedback": "short explanation of issues if any" +}} + +Return ONLY JSON. +""" diff --git a/indexing/indexingController.py b/indexing/indexingController.py new file mode 100644 index 0000000000000000000000000000000000000000..d0b0c6cd5edd38c0858a19eab2bf2344a44cca73 --- /dev/null +++ b/indexing/indexingController.py @@ -0,0 +1,111 @@ +from stores.llm.LLMProviderFactory import LLMProviderFactory +from stores.vector_store.Qdrant import QdrantStore + +from ingestion.loaders.File_loader import load_file +from ingestion.chunkers.recursive_chunker import recursive_chunk +from ingestion.pdf_outline import extract_pdf_outline, build_page_bookmark_map , recursive_chunk_with_pages +from ingestion.loaders.pdf_loader import load_pdf_with_pages + +from config import get_settings + +import os +from qdrant_client import QdrantClient , models + +class IndexingController: + def __init__(self): + config = get_settings() + self.factory = LLMProviderFactory(config) + self.embedder = self.factory.create(config.EMBEDDING_BACKEND) + self.embedder.set_embedding_model(config.EMBEDDING_MODEL_ID, config.EMBEDDING_MODEL_SIZE) + if config.QDRANT_TYPE == "cloud": + self.vector_store_client = QdrantClient(url=config.QDRANT_DOCKER_URL,api_key=config.QDRANT_API_KEY,timeout=120) + elif config.QDRANT_TYPE == "docker": + self.vector_store_client = QdrantClient(url=config.QDRANT_DOCKER_URL,timeout=120) + elif config.QDRANT_TYPE == "local": + self.vector_store_client = QdrantClient(path="data/qdrant",prefer_grpc=False,timeout=120) + + string_fields = ["metadata.username", "metadata.source", "metadata.course","metadata.bookmark_path"] + + if not self.vector_store_client.collection_exists(collection_name=get_settings().QDRANT_COLLECTION): + # 2. Create the collection if it doesn't + self.vector_store_client.create_collection( + collection_name=get_settings().QDRANT_COLLECTION, + vectors_config=models.VectorParams( + size=get_settings().EMBEDDING_MODEL_SIZE, + distance=models.Distance.COSINE + ), + ) + + for field in string_fields: + self.vector_store_client.create_payload_index( + collection_name=get_settings().QDRANT_COLLECTION, + field_name=field, + field_schema=models.KeywordIndexParams( + type=models.KeywordIndexType.KEYWORD + ) + ) + + self.vector_store= QdrantStore(self.vector_store_client,config.QDRANT_COLLECTION, config.EMBEDDING_MODEL_SIZE) + + def embed_chunks(self, chunks): + return self.embedder.embed_text_batch(chunks) + + def process_file(self,file_path, original_filename, username=None, course=None): + file_name = os.path.basename(file_path) + ext = os.path.splitext(file_path)[1].lower() + + bookmark_map = {} + + if ext == ".pdf": + outline , total_pages= extract_pdf_outline(file_path) + bookmark_map = build_page_bookmark_map(outline , total_pages) + + pages = load_pdf_with_pages(file_path) + chunks = recursive_chunk_with_pages(pages) + + else: + text = load_file(file_path) + if isinstance(text, list): + text = " ".join([doc.page_content for doc in text]) + chunks_text = recursive_chunk(text) + chunks = [{"text": c, "page": None} for c in chunks_text] + + embeddings = self.embed_chunks([c["text"] for c in chunks]) + + valid_embs = [] + valid_payloads = [] + + for idx, (chunk_obj, emb) in enumerate(zip(chunks, embeddings)): + if emb is not None: + page = chunk_obj["page"] + bookmark_path = bookmark_map.get(page, []) + + valid_embs.append(emb) + valid_payloads.append({ + "content": chunk_obj["text"], + "metadata": { + "source": original_filename, + "chunk_index": idx, + "total_chunks": len(chunks), + "username": username, + "course": course, + "page": page, + "bookmark_path": bookmark_path, + } + } + ) + print(f"[DEBUG] Prepared payload for chunk {idx}: page={page}, bookmark_path={bookmark_path}") + + self.vector_store.upsert_embeddings( + self.vector_store_client, + get_settings().QDRANT_COLLECTION, + valid_embs, + valid_payloads + ) + print(f"[INFO] Stored {len(valid_embs)} embeddings for file '{file_name}'.") + + return { + "num_chunks": len(chunks), + "chunks": chunks, + "embeddings": embeddings + } diff --git a/ingestion/chunkers/__init__.py b/ingestion/chunkers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ingestion/chunkers/fixed_chunker.py b/ingestion/chunkers/fixed_chunker.py new file mode 100644 index 0000000000000000000000000000000000000000..4b3000ab64429b8d09b4ceb5a295293838c21600 --- /dev/null +++ b/ingestion/chunkers/fixed_chunker.py @@ -0,0 +1,10 @@ +from langchain_text_splitters import CharacterTextSplitter +from config import get_settings + +def fixed_chunk(text): + splitter = CharacterTextSplitter( + chunk_size=get_settings().CHUNK_SIZE, + chunk_overlap=get_settings().CHUNK_OVERLAP + ) + chunks = splitter.split_text(text) + return chunks diff --git a/ingestion/chunkers/recursive_chunker.py b/ingestion/chunkers/recursive_chunker.py new file mode 100644 index 0000000000000000000000000000000000000000..46587c16ee034a78051f4ca15d2726370734ba0f --- /dev/null +++ b/ingestion/chunkers/recursive_chunker.py @@ -0,0 +1,9 @@ +from langchain_text_splitters import RecursiveCharacterTextSplitter +from config import get_settings + +def recursive_chunk(text): + splitter = RecursiveCharacterTextSplitter( + chunk_size=get_settings().CHUNK_SIZE, + chunk_overlap=get_settings().CHUNK_OVERLAP, + ) + return splitter.split_text(text) diff --git a/ingestion/loaders/File_loader.py b/ingestion/loaders/File_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..f79b903dd8e6014eac37cebd641dc3a0f41954e8 --- /dev/null +++ b/ingestion/loaders/File_loader.py @@ -0,0 +1,57 @@ +from config import get_settings +import os + +def get_file_extension(file_id: str): + return os.path.splitext(file_id)[-1] + + +def load_file(file_path: str): + if get_settings().CustomLoaders==True: + from ingestion.loaders.pdf_loader import load_pdf + from ingestion.loaders.txt_loader import load_txt + from ingestion.loaders.md_loader import load_md + from ingestion.loaders.docx_loader import load_docx + + + #Dispatcher + + ext = os.path.splitext(file_path)[1].lower() + + if ext == ".pdf": + docs = load_pdf(file_path) + elif ext == ".docx": + docs = load_docx(file_path) + elif ext == ".md": + docs = load_md(file_path) + elif ext == ".txt": + docs = load_txt(file_path) + else: + print(f"Unsupported file type: {ext}") + return [] + + # Return list of Document objects as-is + return docs + + + elif get_settings().CustomLoaders==False: + + from langchain_community.document_loaders import ( + TextLoader, + Docx2txtLoader, + UnstructuredMarkdownLoader, + PyMuPDFLoader, + ) + + + extension = get_file_extension(file_path) + + if extension == ".txt": + return TextLoader(file_path, encoding="utf8").load() + elif extension == ".docx": + return Docx2txtLoader(file_path).load() + elif extension == ".md": + return UnstructuredMarkdownLoader(file_path).load() + elif extension in [".pdf"]: + return PyMuPDFLoader(file_path).load() + else: + raise ValueError(f"Unsupported file extension: {extension}") diff --git a/ingestion/loaders/__init__.py b/ingestion/loaders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/ingestion/loaders/__init__.py @@ -0,0 +1 @@ + diff --git a/ingestion/loaders/docx_loader.py b/ingestion/loaders/docx_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..20b6112c2185a3383618f73816f24d0db235e7e2 --- /dev/null +++ b/ingestion/loaders/docx_loader.py @@ -0,0 +1,89 @@ +import os +from typing import List +from langchain_core.documents import Document +from docx import Document as DocxDocument +from docx.oxml.table import CT_Tbl +from docx.oxml.text.paragraph import CT_P +from ingestion.loaders.normalization import normalize_text + +def table_to_text(table) -> str: + """Convert DOCX table to plain, readable text without numeric headers.""" + data = [] + try: + for row in table.rows: + row_data = [normalize_text(cell.text) for cell in row.cells] + if any(row_data): # skip empty rows + data.append(row_data) + + if not data: + return "" + + # Format as a readable markdown-like table instead of CSV with numbers + return "\n".join([" | ".join(row) for row in data]) + + except Exception as e: + print(f"Error converting table to text: {e}") + return "" + + + + +def load_docx(file_path: str) -> List[Document]: + """Load DOCX file safely, preserving tables and skipping corrupted sections.""" + docs = [] + + if not os.path.exists(file_path): + print(f"File not found: {file_path}") + return [] + + try: + doc = DocxDocument(file_path) + except Exception as e: + print(f"Failed to open DOCX ({file_path}): {e}") + return [] + + try: + body_elements = list(doc.element.body) + paragraph_iter = iter(doc.paragraphs) + table_iter = iter(doc.tables) + + for element in body_elements: + if isinstance(element, CT_P): + try: + para = next(paragraph_iter) + cleaned = normalize_text(para.text) + if cleaned: + docs.append( + Document( + page_content=cleaned, + metadata={"source": file_path, "type": "text"}, + ) + ) + + except StopIteration: + continue + except Exception as e: + print(f"Error reading paragraph: {e}") + continue + elif isinstance(element, CT_Tbl): + try: + table = next(table_iter) + table_text = table_to_text(table) + if table_text: + docs.append( + Document( + page_content=table_text, + metadata={"source": file_path, "type": "table"}, + ) + ) + except StopIteration: + continue + except Exception as e: + print(f"Error reading table: {e}") + continue + + except Exception as e: + print(f"[WARN] Error processing DOCX ({file_path}): {e}") + return [] + + return docs diff --git a/ingestion/loaders/md_loader.py b/ingestion/loaders/md_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..ed27c8103d630b75eb81ebc908d20e662afdec8a --- /dev/null +++ b/ingestion/loaders/md_loader.py @@ -0,0 +1,48 @@ +import os +import re +from typing import List +from langchain_core.documents import Document +from ingestion.loaders.normalization import normalize_text + + +def load_md(file_path: str) -> List[Document]: + """Load Markdown safely, preserving inline tables and skipping unreadable sections.""" + if not os.path.exists(file_path): + print(f"File not found: {file_path}") + return [] + + text = "" + try: + with open(file_path, "r", encoding="utf-8") as f: + text = f.read() + except UnicodeDecodeError: + try: + with open(file_path, "r", encoding="latin-1") as f: + text = f.read() + except Exception as e: + print(f"Failed to read Markdown file ({file_path}): {e}") + return [] + except Exception as e: + print(f"Could not open Markdown file ({file_path}): {e}") + return [] + + docs = [] + try: + # Split into segments alternating between text and tables + parts = re.split(r"((?:\|.*\|\n)+)", text) + for part in parts: + if not part.strip(): + continue + + # Detect if segment is a table + content_type = "table" if re.match(r"(?:\|.*\|\n)+", part) else "text" + + # Clean markdown formatting but keep structure + cleaned = normalize_text(re.sub(r'(```.*?```|`.*?`|\*\*|__|#)', '', part, flags=re.DOTALL)) + if cleaned: + docs.append(Document(page_content=cleaned, metadata={"source": file_path, "type": content_type})) + except Exception as e: + print(f"Error parsing Markdown file ({file_path}): {e}") + return [] + + return docs diff --git a/ingestion/loaders/normalization.py b/ingestion/loaders/normalization.py new file mode 100644 index 0000000000000000000000000000000000000000..42602577045d4990bc04d9c14cd25505b5db9132 --- /dev/null +++ b/ingestion/loaders/normalization.py @@ -0,0 +1,35 @@ +import re + + +def normalize_text(text: str) -> str: + """Clean and normalize extracted text from any format (PDF/DOCX/MD/TXT).""" + if not text: + return "" + + # Replace common PDF CID artifacts like (cid:1234) + text = re.sub(r'\(cid:\d+\)', '', text) + + # Replace newlines/tabs with spaces + text = text.replace('\n', ' ').replace('\t', ' ') + + # Remove emojis and pictographs + emoji_pattern = re.compile( + "[" + "\U0001F600-\U0001F64F" # emoticons + "\U0001F300-\U0001F5FF" # symbols & pictographs + "\U0001F680-\U0001F6FF" # transport & map + "\U0001F1E0-\U0001F1FF" # flags + "\U00002500-\U00002BEF" + "\U00002700-\U000027BF" + "\U0001F900-\U0001F9FF" + "\U0001FA70-\U0001FAFF" + "\U00002600-\U000026FF" + "\U00002B00-\U00002BFF" + "]+", flags=re.UNICODE + ) + text = emoji_pattern.sub("", text) + + # Collapse multiple spaces + text = re.sub(r'\s+', ' ', text) + + return text.strip() diff --git a/ingestion/loaders/pdf_loader.py b/ingestion/loaders/pdf_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..353122d0d9350cd4e0b85416b855e011aeb71936 --- /dev/null +++ b/ingestion/loaders/pdf_loader.py @@ -0,0 +1,66 @@ +import os +from langchain_core.documents import Document +import pdfplumber +from ingestion.loaders.normalization import normalize_text + +def load_pdf(file_path: str): + documents = [] + # Check if file exists + if not os.path.exists(file_path): + raise FileNotFoundError(f"File not found: {file_path}") + + try: + with pdfplumber.open(file_path) as pdf: + for page_num, page in enumerate(pdf.pages, start=1): + try: + text = page.extract_text() or "" + text = normalize_text(text) + tables = page.extract_tables() or [] + + # Reconstruct page text with tables preserved in order + page_content = text.strip() + for t_idx, table in enumerate(tables, start=1): + table_text = "\n".join( + ["\t".join(cell if cell else "" for cell in row) for row in table] + ) + table_text = normalize_text(table_text) + page_content += f"\n\n=== Table {t_idx} (Page {page_num}) ===\n{table_text}" + + # Append as LangChain Document + documents.append( + Document( + page_content=page_content, + metadata={ + "source": os.path.basename(file_path), + "page_number": page_num, + }, + ) + ) + except Exception as e: + print(f"Error extracting page {page_num}: {e}") + continue # Skip corrupted pages, process others + + except Exception as e: + print(f"Failed to open or read PDF file: {file_path}") + print(f"Error: {e}") + return [] # Return empty list instead of crashing + + return documents + + + + + +def load_pdf_with_pages(file_path: str): + import fitz + doc = fitz.open(file_path) + pages = [] + + for i, page in enumerate(doc): + pages.append({ + "page": i + 1, + "text": page.get_text() + }) + + return pages + diff --git a/ingestion/loaders/txt_loader.py b/ingestion/loaders/txt_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..d64330bbdb043e1a8b41f46031a34f49ce0dec6b --- /dev/null +++ b/ingestion/loaders/txt_loader.py @@ -0,0 +1,38 @@ +import os +from typing import List +from langchain_core.documents import Document +from ingestion.loaders.normalization import normalize_text + +def load_txt(file_path: str) -> List[Document]: + """Load plain text file safely, handling encoding issues.""" + docs = [] + + if not os.path.exists(file_path): + print(f"File not found: {file_path}") + return docs + + text = "" + try: + with open(file_path, "r", encoding="utf-8") as f: + text = f.read() + except UnicodeDecodeError: + try: + with open(file_path, "r", encoding="latin-1") as f: + text = f.read() + except Exception as e: + print(f"Failed to read text file ({file_path}): {e}") + return docs + except Exception as e: + print(f"Could not open file ({file_path}): {e}") + return docs + + try: + cleaned = normalize_text(text) + if cleaned: + docs.append( + Document(page_content=cleaned, metadata={"source": file_path, "type": "text"}) + ) + except Exception as e: + print(f"Error processing text file ({file_path}): {e}") + + return docs diff --git a/ingestion/pdf_outline.py b/ingestion/pdf_outline.py new file mode 100644 index 0000000000000000000000000000000000000000..6f30593f7f4f37accd952388fca522eaf9410a2d --- /dev/null +++ b/ingestion/pdf_outline.py @@ -0,0 +1,62 @@ +import fitz +from config import get_settings +from langchain_text_splitters import RecursiveCharacterTextSplitter + +def extract_pdf_outline(pdf_path: str): + doc = fitz.open(pdf_path) + toc = doc.get_toc(simple=False) + total_pages = doc.page_count + + outline = [] + stack = [] + for level, title, page, *_ in toc: + while stack and stack[-1]["level"] >= level: + stack.pop() + node = {"level": level, "title": title, "page": page, "children": []} + if stack: + stack[-1]["children"].append(node) + else: + outline.append(node) + stack.append(node) + + doc.close() + return outline , total_pages + +def build_page_bookmark_map(outline_tree, total_pages: int): + explicit_map = {} + + def walk(node, path): + current_path = path + [node["title"]] + explicit_map[node["page"]] = current_path + for child in node["children"]: + walk(child, current_path) + + for root in outline_tree: + walk(root, []) + + page_map = {} + last_known_path = [] + + for page_num in range(1, total_pages + 1): + if page_num in explicit_map: + last_known_path = explicit_map[page_num] + page_map[page_num] = last_known_path # carries forward last bookmark + + return page_map + +def recursive_chunk_with_pages(pages): + splitter = RecursiveCharacterTextSplitter( + chunk_size=get_settings().CHUNK_SIZE, + chunk_overlap=get_settings().CHUNK_OVERLAP, + ) + + chunks = [] + for p in pages: + page_chunks = splitter.split_text(p["text"]) + for c in page_chunks: + chunks.append({ + "text": c, + "page": p["page"] + }) + + return chunks diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..70ef0042456848b202b146f7409845e856c55cb1 --- /dev/null +++ b/main.py @@ -0,0 +1,20 @@ +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from routes.base import base_router +from routes.assisstant_rag import assisstant_router +from routes.exam_router import exam_router +from routes.exam_grading_router import grading_router + +app = FastAPI() +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +app.include_router(base_router) +app.include_router(assisstant_router) +app.include_router(exam_router) +app.include_router(grading_router) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..f8b4e9ede07457a262f017cefbeffb3cb267ae54 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,19 @@ +fastapi==0.120.0 +uvicorn==0.38.0 +python-dotenv==1.2.1 +pdfplumber==0.11.7 +python-docx==1.2.0 +pandas==2.3.3 +langchain==1.0.2 +unstructured==0.18.15 +PyMuPDF==1.26.5 +docx2txt==0.9 +Markdown==3.9 +python-multipart==0.0.20 +cohere==5.5.8 +openai==1.35.13 +qdrant-client== 1.16.1 +httpx==0.28.1 +redis==7.2.0 +celery==5.6.2 +json_repair==0.58.5 \ No newline at end of file diff --git a/routes/__init__.py b/routes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/routes/assisstant_rag.py b/routes/assisstant_rag.py new file mode 100644 index 0000000000000000000000000000000000000000..e9cdd792e3760896e5a1baa17413998f144cb22d --- /dev/null +++ b/routes/assisstant_rag.py @@ -0,0 +1,165 @@ +from fastapi import APIRouter , UploadFile, File +from routes.schemas.Requests_Models import ChatRequest +from generation.AssistantRagGenerator import AssistantRagGen +from indexing.indexingController import IndexingController +from uuid import uuid4 +from worker.tasks import process_file_task +from celery.result import AsyncResult +from celery_app import celery_app + +assisstant_router = APIRouter(tags=["assistant_rag"]) + +@assisstant_router.get("/jobs/{job_id}") +def get_job_status(job_id: str): + result = AsyncResult(job_id, app=celery_app) + if result.state == "PENDING": + return {"job_id": job_id,"state": result.state,"message": "Job is waiting in queue",} + + if result.state == "STARTED": + return { + "job_id": job_id, + "state": result.state, + "message": "Job is currently processing", + } + + if result.state == "SUCCESS": + return { + "job_id": job_id, + "state": result.state, + "result": result.result, + } + + if result.state == "FAILURE": + return { + "job_id": job_id, + "state": result.state, + "error": str(result.result), + } + + return { + "job_id": job_id, + "state": result.state, + } + +@assisstant_router.post("/process-file") +async def process_file_endpoint(course: str , username: str , file: UploadFile = File(...)): + job_id = uuid4().hex + temp_path = f"./temp_{job_id}_{file.filename}" + with open(temp_path, "wb") as f: + f.write(await file.read()) + task = process_file_task.delay(temp_path, file.filename, username, course) + return { + "job_id": task.id, + "filename": file.filename, + "status": "queued", + } + +@assisstant_router.post("/chat/complete") +async def chat_complete_endpoint(request: ChatRequest): + indexing_controller = IndexingController() + rag_gen = AssistantRagGen() + user_query = request.prompt if request.prompt else "no question provided" + route = rag_gen.robust_router({"question": user_query}) + + results = [] + context_text = "" + filters = [] + + # Kda Kda pdf :) + if request.source_file or request.bookmark: + if request.bookmark and not request.source_file: + request.bookmark=None + route = "pdf_query" + + if route == "user_info": + if request.role == "instructor" or request.role == "admin": + context_text = ( + f"User Profile Info: {request.user_info.model_dump()}\n" + f"Role: {request.role}\n" + f"Username: {request.username}" + ) + elif request.role == "student": + request.user_info=request.user_info.copy(update={"instructor_owned_files": None}) + context_text = ( + f"User Profile Info: {request.user_info.model_dump()}\n" + f"Role: {request.role}\n" + f"Username: {request.username}" + ) + + elif route == "site_query": + filters = [ + {"field": "course", "op": "eq", "value": "Instructions", "clause": "must"}, + {"field": "username", "op": "eq", "value": "ADMIN", "clause": "must"} + ] + embedding = indexing_controller.embedder.embed_text(user_query) + results = indexing_controller.vector_store.query_qdrant( + filters=filters, + embedding=embedding, + top_k=request.top_k + ) + + elif route == "pdf_query": + if request.role == "student": + enrolled = request.user_info.courses or [] + print(f"[DEBUG] Student {request.username} is enrolled in courses: {enrolled}") + filters.append({"field": "course", "op": "in", "value": enrolled, "clause": "must"}) + + elif request.role == "instructor": + owned = request.user_info.courses + # if owned == []: + # owned = indexing_controller.vector_store.all_user_files_bookmarks(request.username) + # owned = owned.keys() + print(f"[DEBUG] Instructor {request.username} owns courses/files: {owned}") + filters.append({"field": "course", "op": "in", "value": owned, "clause": "must"}) + + if request.source_file: + filters.append({"field": "source", "op": "eq", "value": request.source_file, "clause": "must"}) + + if request.bookmark: + filters.append({"field": "bookmark_path", "op": "text", "value": request.bookmark, "clause": "must"}) + + embedding = indexing_controller.embedder.embed_text(user_query) + results = indexing_controller.vector_store.query_qdrant( + filters=filters, + embedding=embedding, + top_k=request.top_k + ) + + if not context_text and results: + context_text = "\n\n".join([r["content"] for r in results if r.get("content")]) + + history_str = "\n".join( + f"Human: {turn.Human_msg}\nAssistant: {turn.LLM_response}" + for turn in request.history + ) if request.history else "None" + + if route == "user_info": + final_prompt = rag_gen.build_user_info_prompt( + question=user_query, + conversation_history=history_str, + User_Info=str(request.user_info.model_dump()), + ) + elif route == "site_query": + final_prompt = rag_gen.build_site_query_prompt( + question=user_query, + context=context_text, + conversation_history=history_str + ) + else: + final_prompt = rag_gen.build_unified_prompt( + context=context_text, + question=user_query, + conversation_history=history_str, + User_Info=str(request.user_info.model_dump()), + ) + + llm_response = rag_gen.generator.generate_text(prompt=final_prompt) + + return { + "session_id": request.session_id, # Return as is + "route": route, + "query": user_query, + "history": request.history, # Return as is + "results": results, + "LLM_answer": llm_response, + } diff --git a/routes/base.py b/routes/base.py new file mode 100644 index 0000000000000000000000000000000000000000..977de5bd350d38afbb2c823c3944241bbc054afa --- /dev/null +++ b/routes/base.py @@ -0,0 +1,45 @@ +from fastapi import APIRouter , Depends +from config import get_settings +from indexing.indexingController import IndexingController + +base_router = APIRouter(tags=["base"]) + + +@base_router.get("/health") +async def health_check(settings = Depends(get_settings)): + return {"status": "ok", "app_name": settings} + +# @base_router.post("/all_docs") +# async def get_all_docs(): +# indexing_controller = IndexingController() +# all_docs = indexing_controller.vector_store.get_all_documents() +# return { +# "total_docs": len(all_docs), +# "documents": all_docs +# } + +@base_router.get("/all_files") +async def get__files(): + indexing_controller = IndexingController() + all_files = indexing_controller.vector_store.get_all_files() + return { + "total_files": len(all_files), + "files": all_files,} + + +@base_router.get("/remove_file") +async def remove_file(filename: str,username: str ,course: str): + indexing_controller = IndexingController() + result = indexing_controller.vector_store.remove_points_by_file(filename,username,course) + return { + "status": "success" if result else "failure", + "message": f"File '{filename}' removed." if result else f"File '{filename}' not found." + } + +@base_router.get("/user/docs") +async def get_user_docs(username: str): + indexing_controller = IndexingController() + user_docs = indexing_controller.vector_store.all_user_files_bookmarks(username) + return { + "total_docs": len(user_docs), + "documents": user_docs} \ No newline at end of file diff --git a/routes/exam_grading_router.py b/routes/exam_grading_router.py new file mode 100644 index 0000000000000000000000000000000000000000..f4cbb96a5cf7a8e0736b2ccc48331f920e66c7ed --- /dev/null +++ b/routes/exam_grading_router.py @@ -0,0 +1,122 @@ +from fastapi import APIRouter, HTTPException +from pydantic import BaseModel +import re + +from generation.ExamAnswer import ExamGradingService, grade_exam_task +from generation.answer_models import ExamSubmission, ExamResult +from routes.schemas.Exam_Models import ExamResponse + +grading_router = APIRouter(prefix="/exam/grading", tags=["exam_grading"]) + +class GradingResponse(BaseModel): + job_id: str + exam_id: str + student_id: str + status: str + +class GradingRequest(BaseModel): + submission: ExamSubmission + exam: ExamResponse + + +def normalize_text(text: str) -> str: + if not text: + return "" + text = re.sub(r'[^\w\s]', '', text) + text = re.sub(r'\s+', ' ', text) + return text.strip().lower() + +@grading_router.post("/submit", response_model=GradingResponse) +async def submit_exam(request: GradingRequest): + submission_dict = request.submission.model_dump() + exam_questions_map = {} + + for q in request.exam.questions: + normalized_q = normalize_text(q.question) + exam_questions_map[normalized_q] = q + + for answer in submission_dict["answers"]: + question_text = answer["question_text"] + question_type = answer["question_type"] + normalized_answer_text = normalize_text(question_text) + + + correct_answer = None + if normalized_answer_text in exam_questions_map: + q = exam_questions_map[normalized_answer_text] + + if question_type == "multiple_choice" and hasattr(q, 'correct_answer'): + correct_answer = q.correct_answer + elif question_type == "true_false" and hasattr(q, 'correct_answer'): + correct_answer = q.correct_answer + elif question_type == "short_answer" and hasattr(q, 'answer'): + correct_answer = q.answer + elif question_type == "code" and hasattr(q, 'solution'): + correct_answer = q.solution + elif question_type == "essay": + if hasattr(q, 'answer_guidelines') and q.answer_guidelines: + correct_answer = q.answer_guidelines + elif hasattr(q, 'answer'): + correct_answer = q.answer + + + if "metadata" not in answer: + answer["metadata"] = {} + answer["metadata"]["correct_answer"] = correct_answer + + + task = grade_exam_task.delay(submission_dict) + + return GradingResponse( + job_id=task.id, + exam_id=request.submission.exam_id, + student_id=request.submission.student_id, + status="queued" + ) + +@grading_router.post("/grade-sync", response_model=ExamResult) +async def grade_sync(request: GradingRequest): + try: + service = ExamGradingService(use_ai_for_essays=True) + + + exam_questions_map = {} + for q in request.exam.questions: + normalized_q = normalize_text(q.question) + exam_questions_map[normalized_q] = q + + for ans in request.submission.answers: + question_text = ans.question_text + question_type = ans.question_type + normalized_answer_text = normalize_text(question_text) + + + correct_answer = None + if normalized_answer_text in exam_questions_map: + q = exam_questions_map[normalized_answer_text] + + + if question_type == "multiple_choice" and hasattr(q, 'correct_answer'): + correct_answer = q.correct_answer + elif question_type == "true_false" and hasattr(q, 'correct_answer'): + correct_answer = q.correct_answer + elif question_type == "short_answer" and hasattr(q, 'answer'): + correct_answer = q.answer + elif question_type == "code" and hasattr(q, 'solution'): + correct_answer = q.solution + elif question_type == "essay": + if hasattr(q, 'answer_guidelines') and q.answer_guidelines: + correct_answer = q.answer_guidelines + elif hasattr(q, 'answer'): + correct_answer = q.answer + + if correct_answer is not None: + if not ans.metadata: + ans.metadata = {} + ans.metadata["correct_answer"] = correct_answer + + + result = service.grade_submission(request.submission) + return result + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) \ No newline at end of file diff --git a/routes/exam_router.py b/routes/exam_router.py new file mode 100644 index 0000000000000000000000000000000000000000..60f7b9cf296e32eb7dca00d1dceb46ec7981fe53 --- /dev/null +++ b/routes/exam_router.py @@ -0,0 +1,15 @@ +from fastapi import APIRouter +from routes.schemas.Exam_Models import ExamGenerationRequest +from worker.tasks import generate_exam_task + +exam_router = APIRouter(prefix="/exam", tags=["exam"]) + + +@exam_router.post("/create") +async def process_file_endpoint(request: ExamGenerationRequest): + task = generate_exam_task.delay(request.model_dump()) + return { + "job_id": task.id, + "exam_id": request.exam_id, + "status": "queued", + } diff --git a/routes/schemas/Exam_Models.py b/routes/schemas/Exam_Models.py new file mode 100644 index 0000000000000000000000000000000000000000..4f79c79a83dae394cf62dad30b5deeecfd11b147 --- /dev/null +++ b/routes/schemas/Exam_Models.py @@ -0,0 +1,180 @@ +from pydantic import BaseModel, field_validator, model_validator +from typing import List, Optional, Dict +from enum import Enum +from typing import Union +from typing import Literal +from pydantic import Field +from typing import Annotated + + +class QuestionType(str, Enum): + MCQ = "mcq" + TRUE_FALSE = "true_false" + SHORT_ANSWER = "short_answer" + ESSAY = "essay" + CODE = "code" + +class DifficultyLevel(str, Enum): + EASY = "easy" + MEDIUM = "medium" + HARD = "hard" + +class Reference(BaseModel): + filename: str + bookmarks: Optional[List[str]] = None + +class ExamGenerationRequest(BaseModel): + username: str + course: str + exam_id: str + total_questions: int + topics: List[str] + references: Optional[List[Reference]] = None + difficulty: Optional[DifficultyLevel] = DifficultyLevel.MEDIUM + include_answer_key: Optional[bool] = True + question_types_distribution: Dict[QuestionType, int] + model_config = {"extra": "ignore"} + + @field_validator("topics") + @classmethod + def validate_topics(cls, v): + if not v: + raise ValueError("Topics cannot be empty") + return v + + @field_validator("question_types_distribution") + @classmethod + def validate_positive(cls, v): + if any(count <= 0 for count in v.values()): + raise ValueError("All distribution counts must be > 0") + return v + + @model_validator(mode="after") + def validate_sum(self): + if sum(self.question_types_distribution.values()) != self.total_questions: + raise ValueError("Distribution must equal total_questions") + return self + +class QuestionBase(BaseModel): + type: QuestionType + question: str + model_config = {"extra": "ignore"} + +class MCQQuestion(QuestionBase): + type: Literal[QuestionType.MCQ] + options: List[str] + correct_answer: str + explanation: str + + @model_validator(mode="after") + def validate_mcq(self): + if len(self.options) < 2: + raise ValueError("MCQ must contain at least 2 options") + if self.correct_answer not in self.options: + raise ValueError("correct_answer must exist in options") + return self + +class TrueFalseQuestion(QuestionBase): + type: Literal[QuestionType.TRUE_FALSE] + correct_answer: bool + explanation: str + +class ShortAnswerQuestion(QuestionBase): + type: Literal[QuestionType.SHORT_ANSWER] + answer: str + explanation: str + +class EssayQuestion(QuestionBase): + type: Literal[QuestionType.ESSAY] + answer: str + answer_guidelines: str + +class CodeQuestion(QuestionBase): + type: Literal[QuestionType.CODE] + + starter_code: Optional[str] = Field( + default=None, + description="Starter code shown to the student" + ) + + language: str= "c" + + solution: str = Field( + description="Correct full solution code" + ) + + explanation: str = Field( + description="Explanation of how the solution works" + ) + + @field_validator("starter_code", "solution") + @classmethod + def normalize_code(cls, v): + """Convert escaped newlines to real newlines if present.""" + if v: + return v.replace("\\n", "\n") + return v + +QuestionUnion = Annotated[ + Union[ + MCQQuestion, + TrueFalseQuestion, + ShortAnswerQuestion, + EssayQuestion, + CodeQuestion, + ], + Field(discriminator="type"), +] + +class ExamResponse(BaseModel): + exam_id: str + difficulty: DifficultyLevel + total_questions: int + questions: List[QuestionUnion] + expected_distribution: Dict[QuestionType, int] + model_config = {"extra": "ignore"} + + @model_validator(mode="after") + def validate_question_count(self): + if len(self.questions) != self.total_questions: + raise ValueError( + f"Expected {self.total_questions} questions, " + f"but got {len(self.questions)}" + ) + return self + @model_validator(mode="after") + def validate_distribution(self): + + actual_counts: Dict[QuestionType, int] = {} + + for q in self.questions: + actual_counts[q.type] = actual_counts.get(q.type, 0) + 1 + + if set(actual_counts.keys()) != set(self.expected_distribution.keys()): + raise ValueError("Unexpected question types in exam") + + for q_type, expected_count in self.expected_distribution.items(): + actual = actual_counts.get(q_type, 0) + + if actual != expected_count: + raise ValueError( + f"Distribution mismatch for {q_type.value}: " + f"expected {expected_count}, got {actual}" + ) + + return self + +class AnswerItem(BaseModel): + question_index: int + answer: str + +class AnswerKey(BaseModel): + exam_id: str + answers: List[AnswerItem] + model_config = {"extra": "ignore"} + +class EvaluationResult(BaseModel): + overall_score: int + feedback: str + model_config = {"extra": "ignore"} + \ No newline at end of file diff --git a/routes/schemas/Requests_Models.py b/routes/schemas/Requests_Models.py new file mode 100644 index 0000000000000000000000000000000000000000..a1c0cbb42e65c75016f9e2ea176add54298dfde6 --- /dev/null +++ b/routes/schemas/Requests_Models.py @@ -0,0 +1,24 @@ +from pydantic import BaseModel +from typing import Optional, List + +class ConversationTurn(BaseModel): + Human_msg: str + LLM_response: str + +class UserInfoRequest(BaseModel): + courses: Optional[List[str]] = None + deadlines: Optional[List[str]] = None + grades: Optional[List[str]] = None + instructor_owned_files: Optional[List[str]] = None + more_info: Optional[str] = None + +class ChatRequest(BaseModel): + prompt: Optional[str] = None + username: str + session_id: str + role: str + top_k: int = 5 + source_file: Optional[str] = None + bookmark: Optional[str] = None + history: Optional[List[ConversationTurn]] = None + user_info: Optional[UserInfoRequest]= None diff --git a/routes/schemas/__init__.py b/routes/schemas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/stores/llm/LLMEnums.py b/stores/llm/LLMEnums.py new file mode 100644 index 0000000000000000000000000000000000000000..af710979ddcc10b3ce93c94af3ec12778463623d --- /dev/null +++ b/stores/llm/LLMEnums.py @@ -0,0 +1,28 @@ +from enum import Enum + +class LLMEnums(Enum): + OPENAI = "OPENAI" + COHERE = "COHERE" + OLLAMA = "OLLAMA" + MISTRAL = "MISTRAL" + GROQ = "GROQ" + OPENROUTER = "OPENROUTER" + HUGGINGFACE = "HUGGINGFACE" + DEEPSEEK = "DEEPSEEK" + GEMINI = "GEMINI" + +class OpenAIEnums(Enum): + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + +class CoHereEnums(Enum): + SYSTEM = "SYSTEM" + USER = "USER" + ASSISTANT = "CHATBOT" + DOCUMENT = "search_document" + QUERY = "search_query" + +class DocumentTypeEnum(Enum): + DOCUMENT = "document" + QUERY = "query" diff --git a/stores/llm/LLMInterface.py b/stores/llm/LLMInterface.py new file mode 100644 index 0000000000000000000000000000000000000000..a634501d68179776e23ffe15d79372612cb5efd2 --- /dev/null +++ b/stores/llm/LLMInterface.py @@ -0,0 +1,24 @@ +from abc import ABC, abstractmethod + +class LLMInterface(ABC): + + @abstractmethod + def set_generation_model(self, model_id: str): + pass + + @abstractmethod + def set_embedding_model(self, model_id: str, embedding_size: int): + pass + + @abstractmethod + def generate_text(self, prompt: str, chat_history: list=[], max_output_tokens: int=None, + temperature: float = None): + pass + + @abstractmethod + def embed_text_batch(self, texts: list[str], batch_size: int = 32): + pass + + @abstractmethod + def construct_prompt(self, prompt: str, role: str): + pass diff --git a/stores/llm/LLMProviderFactory.py b/stores/llm/LLMProviderFactory.py new file mode 100644 index 0000000000000000000000000000000000000000..df83dc3f60ef6f2e4f01f6220dcaa8f3b352da7c --- /dev/null +++ b/stores/llm/LLMProviderFactory.py @@ -0,0 +1,93 @@ +from .LLMEnums import LLMEnums +from stores.llm.providers.OpenAIProvider import OpenAIProvider +from stores.llm.providers.OllamaProvider import OllamaProvider +from stores.llm.providers.CohereProvider import CohereProvider +from stores.llm.providers.MistralProvider import MistralProvider +from stores.llm.providers.GroqProvider import GroqProvider +from stores.llm.providers.OpenRouterProvider import OpenRouterProvider +from stores.llm.providers.HuggingFaceProvider import HuggingFaceProvider +from stores.llm.providers.DeepSeekProvider import DeepSeekProvider +from stores.llm.providers.GeminiProvider import GeminiProvider + + +class LLMProviderFactory: + def __init__(self, config: dict): + self.config = config + + def create(self, provider: str): + + if provider == LLMEnums.OPENAI.value: + return OpenAIProvider( + api_key=self.config.OPENAI_API_KEY, + api_url=self.config.OPENAI_API_URL, + default_input_max_characters=self.config.INPUT_DAFAULT_MAX_CHARACTERS, + default_generation_max_output_tokens=self.config.GENERATION_DAFAULT_MAX_TOKENS, + default_generation_temperature=self.config.GENERATION_DAFAULT_TEMPERATURE, + ) + + if provider == LLMEnums.OLLAMA.value: + return OllamaProvider( + url=self.config.OLLAMA_URL, + api_key=self.config.OLLAMA_API_KEY, + default_input_max_characters=self.config.INPUT_DAFAULT_MAX_CHARACTERS, + default_generation_max_output_tokens=self.config.GENERATION_DAFAULT_MAX_TOKENS, + default_generation_temperature=self.config.GENERATION_DAFAULT_TEMPERATURE, + ) + + if provider == LLMEnums.COHERE.value: + return CohereProvider( + api_key=self.config.COHERE_API_KEY, + default_input_max_characters=self.config.INPUT_DAFAULT_MAX_CHARACTERS, + default_generation_max_output_tokens=self.config.GENERATION_DAFAULT_MAX_TOKENS, + default_generation_temperature=self.config.GENERATION_DAFAULT_TEMPERATURE, + ) + + if provider == LLMEnums.MISTRAL.value: + return MistralProvider( + api_key=self.config.MISTRAL_API_KEY, + default_input_max_characters=self.config.INPUT_DAFAULT_MAX_CHARACTERS, + default_generation_max_output_tokens=self.config.GENERATION_DAFAULT_MAX_TOKENS, + default_generation_temperature=self.config.GENERATION_DAFAULT_TEMPERATURE, + ) + + if provider == LLMEnums.GROQ.value: + return GroqProvider( + api_key=self.config.GROQ_API_KEY, + default_input_max_characters=self.config.INPUT_DAFAULT_MAX_CHARACTERS, + default_generation_max_output_tokens=self.config.GENERATION_DAFAULT_MAX_TOKENS, + default_generation_temperature=self.config.GENERATION_DAFAULT_TEMPERATURE, + ) + + if provider == LLMEnums.OPENROUTER.value: + return OpenRouterProvider( + api_key=self.config.OPENROUTER_API_KEY, + default_input_max_characters=self.config.INPUT_DAFAULT_MAX_CHARACTERS, + default_generation_max_output_tokens=self.config.GENERATION_DAFAULT_MAX_TOKENS, + default_generation_temperature=self.config.GENERATION_DAFAULT_TEMPERATURE, + ) + + if provider == LLMEnums.HUGGINGFACE.value: + return HuggingFaceProvider( + api_key=self.config.HF_API_KEY, + default_input_max_characters=self.config.INPUT_DAFAULT_MAX_CHARACTERS, + default_generation_max_output_tokens=self.config.GENERATION_DAFAULT_MAX_TOKENS, + default_generation_temperature=self.config.GENERATION_DAFAULT_TEMPERATURE, + ) + + if provider == LLMEnums.DEEPSEEK.value: + return DeepSeekProvider( + api_key=self.config.DEEPSEEK_API_KEY, + default_input_max_characters=self.config.INPUT_DAFAULT_MAX_CHARACTERS, + default_generation_max_output_tokens=self.config.GENERATION_DAFAULT_MAX_TOKENS, + default_generation_temperature=self.config.GENERATION_DAFAULT_TEMPERATURE, + ) + + if provider == LLMEnums.GEMINI.value: + return GeminiProvider( + api_key=self.config.GEMINI_API_KEY, + default_input_max_characters=self.config.INPUT_DAFAULT_MAX_CHARACTERS, + default_generation_max_output_tokens=self.config.GENERATION_DAFAULT_MAX_TOKENS, + default_generation_temperature=self.config.GENERATION_DAFAULT_TEMPERATURE, + ) + + return None diff --git a/stores/llm/__init__.py b/stores/llm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/stores/llm/providers/CohereProvider.py b/stores/llm/providers/CohereProvider.py new file mode 100644 index 0000000000000000000000000000000000000000..ea96269e0ee11c28cd2626f4a07d8f17358dda92 --- /dev/null +++ b/stores/llm/providers/CohereProvider.py @@ -0,0 +1,395 @@ +from stores.llm.LLMInterface import LLMInterface +import logging +import requests +import re +import os +import time +import math +class CohereProvider(LLMInterface): + def __init__(self, url: str = None, model: str = None, + default_input_max_characters: int = 1000, + default_generation_max_output_tokens: int = 1000, + default_generation_temperature: float = 0.1, api_key: str = None): + self.url = url or "https://api.cohere.com/v2" + self.api_key = api_key or os.getenv("COHERE_API_KEY") + self.model = model + self.generation_model_id = None + + self.embedding_model = None + self.embedding_model_id = None + self.embedding_size = None + + self.default_input_max_characters = default_input_max_characters + self.default_generation_max_output_tokens = default_generation_max_output_tokens + self.default_generation_temperature = default_generation_temperature + self.logger = logging.getLogger(__name__) + + def set_generation_model(self, model_id: str): + if model_id: + self.model = model_id + + def set_embedding_model(self, model_id: str, embedding_size: int): + if model_id: + self.embedding_model = model_id + self.embedding_size = embedding_size + self.embedding_model_id = model_id + + def process_text(self, text: str): + if not text: + return "" + return str(text).strip() + + def generate_text(self, prompt: str, chat_history: list = None, + max_output_tokens: int = None, temperature: float = None): + try: + chat_history = chat_history or [] # safe handling + clean_prompt = self.process_text(prompt) + + # Build messages list from chat_history + current prompt + messages = [] + for entry in chat_history: + messages.append({ + "role": entry.get("role", "user"), + "content": entry.get("content", "") + }) + messages.append({"role": "user", "content": clean_prompt}) + + payload = { + "model": self.model, + "messages": messages, + "max_tokens": int(max_output_tokens or self.default_generation_max_output_tokens), + "temperature": float(temperature or self.default_generation_temperature), + } + + url = self.url.rstrip("/") + "/chat" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + resp = requests.post(url, json=payload, headers=headers, timeout=6000) + if resp.status_code != 200: + self.logger.error("Cohere generate failed: %s %s", resp.status_code, resp.text) + return None + + data = resp.json() + + # Extract generated text from Cohere v2 chat response + generated_text = "" + try: + generated_text = data["message"]["content"][0]["text"].strip() + except (KeyError, IndexError, TypeError): + self.logger.error("Unexpected Cohere response structure: %s", data) + return None + + if not generated_text: + return None + + # Mirror the same return shape as OllamaProvider + usage = data.get("usage", {}) + return { + "model": data.get("model"), + "response": generated_text, + "tokens_generated": usage.get("tokens", {}).get("output_tokens"), + "total_duration_ms": None, # Cohere does not expose latency in response + "prompt_eval_tokens": usage.get("tokens", {}).get("input_tokens"), + } + + except Exception as e: + self.logger.exception("Error in CohereProvider.generate_text: %s", e) + return None + + def embed_text(self, text: str, document_type: str = None): + """Return an embedding vector from Cohere.""" + try: + if not self.embedding_model: + self.logger.error("Embedding model is not set before calling embed_text()") + return None + + clean_text = self.process_text(text) + print(f"[DEBUG] Cleaned text: '{clean_text[:20]}...'") + if not clean_text: + return [] + + # Cohere requires an input_type; map document_type or fall back to "search_document" + input_type = document_type if document_type in ( + "search_document", "search_query", "classification", "clustering" + ) else "search_document" + + payload = { + "model": self.embedding_model, + "texts": [clean_text], + "input_type": input_type, + "embedding_types": ["float"], + } + + url = self.url.rstrip("/") + "/embed" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + resp = requests.post(url, json=payload, headers=headers, timeout=200) + if resp.status_code != 200: + print(f"[ERROR] Cohere embedding failed: {resp.status_code} {resp.text}") + return None + + data = resp.json() + + # Cohere v2 returns embeddings under data.embeddings.float + embedding = None + try: + embedding = data["embeddings"]["float"][0] + except (KeyError, IndexError, TypeError): + pass + + # Fallback: older v1-style shape + if embedding is None: + try: + embedding = data["embeddings"][0] + except (KeyError, IndexError, TypeError): + pass + + if embedding is not None: + print(f"[DEBUG] Embedding length: {len(embedding)}") + return embedding + + print("[WARNING] 'embedding' key not found in response JSON") + return None + + except Exception as e: + print(f"[EXCEPTION] Error in CohereProvider.embed_text: {e}") + return None + + def construct_prompt(self, prompt: str, role: str): + return { + "role": role, + "content": self.process_text(prompt) + } + def embed_text_batch(self, texts: list[str], batch_size: int = 96): + self.logger.info(f"Embedding {len(texts)} texts using batch_size={batch_size}") + + if not self.embedding_model: + self.logger.error("Embedding model not set") + return None + + all_embeddings = [] + total_batches = math.ceil(len(texts) / batch_size) + + url = self.url.rstrip("/") + "/embed" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + # Cohere free tier: 10 req/min | paid: 100 req/min + # Adjust MIN_SECONDS_PER_REQUEST to match your plan + MIN_SECONDS_PER_REQUEST = 0.65 # ~92 req/min (safe under 100/min paid) + MAX_RETRIES = 5 + BACKOFF_BASE = 10 # seconds — doubles on each retry + + for batch_idx, i in enumerate(range(0, len(texts), batch_size), start=1): + time.sleep(6) + batch = texts[i:i + batch_size] + clean_batch = [self.process_text(t) for t in batch if t] + + # ── Progress ──────────────────────────────────────────────────────── + done_texts = min(i + batch_size, len(texts)) + pct = (batch_idx / total_batches) * 100 + bar_filled = int(pct / 5) # 20-char bar + bar = "█" * bar_filled + "░" * (20 - bar_filled) + print( + f"\r[EMBED] [{bar}] {pct:5.1f}% " + f"batch {batch_idx}/{total_batches} " + f"({done_texts}/{len(texts)} texts)", + end="", flush=True + ) + # ──────────────────────────────────────────────────────────────────── + + payload = { + "model": self.embedding_model, + "texts": clean_batch, + "input_type": "search_document", + "embedding_types": ["float"], + } + + # ── Rate-limited request with exponential back-off ────────────────── + embeddings = None + request_start = time.monotonic() + + for attempt in range(1, MAX_RETRIES + 1): + resp = requests.post(url, json=payload, headers=headers, timeout=200) + + if resp.status_code == 200: + break + + if resp.status_code == 429: + retry_after = float(resp.headers.get("Retry-After", BACKOFF_BASE ** attempt)) + print( + f"\n[RATE LIMIT] batch {batch_idx} — " + f"attempt {attempt}/{MAX_RETRIES}, " + f"waiting {retry_after:.1f}s …" + ) + time.sleep(retry_after) + continue + + # Any other non-200 — log and abort + self.logger.error( + "Cohere embedding failed (batch %d, attempt %d): %s %s", + batch_idx, attempt, resp.status_code, resp.text + ) + return None + + else: + # Exhausted all retries on 429 + self.logger.error( + "Cohere embedding: max retries (%d) exceeded on batch %d", + MAX_RETRIES, batch_idx + ) + return None + + # ── Parse response ────────────────────────────────────────────────── + data = resp.json() + + try: + embeddings = data["embeddings"]["float"] # v2 shape + except (KeyError, TypeError): + embeddings = data.get("embeddings") # v1 shape + + if not embeddings: + self.logger.error("No embeddings returned from Cohere (batch %d)", batch_idx) + return None + + self.logger.debug(f"Received {len(embeddings)} embeddings for batch {batch_idx}") + all_embeddings.extend(embeddings) + + # ── Pace requests to stay under rate limit ────────────────────────── + elapsed = time.monotonic() - request_start + sleep_for = max(0.0, MIN_SECONDS_PER_REQUEST - elapsed) + if sleep_for > 0: + time.sleep(sleep_for) + # ──────────────────────────────────────────────────────────────────── + + # Final newline after the progress bar + print(f"\r[EMBED] [{'█' * 20}] 100.0% " + f"batch {total_batches}/{total_batches} " + f"({len(texts)}/{len(texts)} texts) ✓") + + self.logger.info(f"Total embeddings created: {len(all_embeddings)}") + return all_embeddings + # def embed_text_batch(self, texts: list[str], batch_size: int = 32): + # self.logger.info(f"Embedding {len(texts)} texts using batch_size={batch_size}") + + # if not self.embedding_model: + # self.logger.error("Embedding model not set") + # return None + + # all_embeddings = [] + + # url = self.url.rstrip("/") + "/embed" + # headers = { + # "Authorization": f"Bearer {self.api_key}", + # "Content-Type": "application/json", + # } + + # for i in range(0, len(texts), batch_size): + # batch = texts[i:i + batch_size] + # clean_batch = [self.process_text(t) for t in batch if t] + + # print(f"[EMBED] Embedding {len(texts)} texts using batch_size={batch_size}") + + # payload = { + # "model": self.embedding_model, + # "texts": clean_batch, + # "input_type": "search_document", + # "embedding_types": ["float"], + # } + + # resp = requests.post(url, json=payload, headers=headers, timeout=200) + # if resp.status_code != 200: + # self.logger.error("Cohere embedding failed: %s %s", resp.status_code, resp.text) + # return None + + # data = resp.json() + + # # Handle both v2 (embeddings.float) and v1 (embeddings) shapes + # embeddings = None + # try: + # embeddings = data["embeddings"]["float"] + # except (KeyError, TypeError): + # embeddings = data.get("embeddings") + + # if not embeddings: + # self.logger.error("No embeddings returned from Cohere") + # return None + + # self.logger.debug(f"Received {len(embeddings)} embeddings") + # all_embeddings.extend(embeddings) + + # self.logger.info(f"Total embeddings created: {len(all_embeddings)}") + # return all_embeddings + + def clean_content(self, text: str) -> str: + text = re.sub(r'\[.*?\]\(.*?\)', '', text) + text = re.sub(r'\[[^\]]*\]', '', text) + text = re.sub(r'\n+', '\n', text).strip() + return text + + def web_search(self, query: str): + """Use Cohere's chat endpoint with web-search connector to perform a search.""" + try: + payload = { + "model": self.model, + "messages": [{"role": "user", "content": query}], + "tools": [{"type": "web_search"}], + } + + url = self.url.rstrip("/") + "/chat" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + resp = requests.post(url, json=payload, headers=headers, timeout=6000) + + if not resp or resp.status_code != 200: + return { + "text": "No relevant external results found.", + "references": [] + } + + data = resp.json() + + combined_text = [] + references = set() + + # Extract assistant text + try: + assistant_text = data["message"]["content"][0]["text"] + combined_text.append(self.clean_content(assistant_text)) + except (KeyError, IndexError, TypeError): + pass + + # Extract citations / source URLs from Cohere's citations block + for citation in data.get("message", {}).get("citations", []): + for source in citation.get("sources", []): + url_val = source.get("url") or source.get("id", "") + if url_val.startswith("http"): + references.add(url_val) + + # Also scan raw text for bare URLs (mirrors Ollama behaviour) + raw_text = "\n".join(combined_text) + for found_url in re.findall(r"https?://[^\s)]+", raw_text): + references.add(found_url) + + return { + "text": "\n\n".join(combined_text[:3]), + "references": list(references) + } + + except Exception as e: + self.logger.error("Cohere web search failed: %s", e) + return { + "text": f"Cohere search error: {str(e)}", + "references": [] + } diff --git a/stores/llm/providers/DeepSeekProvider.py b/stores/llm/providers/DeepSeekProvider.py new file mode 100644 index 0000000000000000000000000000000000000000..5fc890abb26bd8196cd5d3a1022be09d1c3698c8 --- /dev/null +++ b/stores/llm/providers/DeepSeekProvider.py @@ -0,0 +1,126 @@ +from stores.llm.LLMInterface import LLMInterface +import logging +import requests +import re +import os + + +class DeepSeekProvider(LLMInterface): + def __init__(self, url: str = None, model: str = None, + default_input_max_characters: int = 1000, + default_generation_max_output_tokens: int = 1000, + default_generation_temperature: float = 0.1, api_key: str = None): + self.url = url or "https://api.deepseek.com/v1" + self.api_key = api_key or os.getenv("DEEPSEEK_API_KEY") + self.model = model + self.generation_model_id = None + + self.embedding_model = None + self.embedding_model_id = None + self.embedding_size = None + + self.default_input_max_characters = default_input_max_characters + self.default_generation_max_output_tokens = default_generation_max_output_tokens + self.default_generation_temperature = default_generation_temperature + self.logger = logging.getLogger(__name__) + + def set_generation_model(self, model_id: str): + if model_id: + self.model = model_id + + def set_embedding_model(self, model_id: str, embedding_size: int): + if model_id: + self.embedding_model = model_id + self.embedding_size = embedding_size + self.embedding_model_id = model_id + + def process_text(self, text: str): + if not text: + return "" + return str(text).strip() + + def generate_text(self, prompt: str, chat_history: list = None, + max_output_tokens: int = None, temperature: float = None): + try: + chat_history = chat_history or [] + clean_prompt = self.process_text(prompt) + + messages = [] + for entry in chat_history: + messages.append({ + "role": entry.get("role", "user"), + "content": entry.get("content", "") + }) + messages.append({"role": "user", "content": clean_prompt}) + + payload = { + "model": self.model, + "messages": messages, + "max_tokens": int(max_output_tokens or self.default_generation_max_output_tokens), + "temperature": float(temperature or self.default_generation_temperature), + } + + url = self.url.rstrip("/") + "/chat/completions" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + resp = requests.post(url, json=payload, headers=headers, timeout=6000) + if resp.status_code != 200: + self.logger.error("DeepSeek generate failed: %s %s", resp.status_code, resp.text) + return None + + data = resp.json() + + try: + generated_text = data["choices"][0]["message"]["content"].strip() + except (KeyError, IndexError, TypeError): + self.logger.error("Unexpected DeepSeek response structure: %s", data) + return None + + if not generated_text: + return None + + usage = data.get("usage", {}) + return { + "model": data.get("model"), + "response": generated_text, + "tokens_generated": usage.get("completion_tokens"), + "total_duration_ms": None, + "prompt_eval_tokens": usage.get("prompt_tokens"), + } + + except Exception as e: + self.logger.exception("Error in DeepSeekProvider.generate_text: %s", e) + return None + + def embed_text(self, text: str, document_type: str = None): + """DeepSeek does not currently offer an embeddings endpoint — returns None.""" + self.logger.warning("DeepSeekProvider does not support embeddings.") + return None + + def construct_prompt(self, prompt: str, role: str): + return { + "role": role, + "content": self.process_text(prompt) + } + + def embed_text_batch(self, texts: list[str], batch_size: int = 32): + """DeepSeek does not currently offer an embeddings endpoint — returns None.""" + self.logger.warning("DeepSeekProvider does not support embeddings.") + return None + + def clean_content(self, text: str) -> str: + text = re.sub(r'\[.*?\]\(.*?\)', '', text) + text = re.sub(r'\[[^\]]*\]', '', text) + text = re.sub(r'\n+', '\n', text).strip() + return text + + def web_search(self, query: str): + """DeepSeek has no native web search — returns a not-supported notice.""" + self.logger.warning("DeepSeekProvider.web_search is not natively supported.") + return { + "text": "Web search is not natively supported by the DeepSeek API.", + "references": [] + } diff --git a/stores/llm/providers/GeminiProvider.py b/stores/llm/providers/GeminiProvider.py new file mode 100644 index 0000000000000000000000000000000000000000..619ec6f7c0ce88ecb9fef96eac00a8e56773390d --- /dev/null +++ b/stores/llm/providers/GeminiProvider.py @@ -0,0 +1,305 @@ +import time + +from stores.llm.LLMInterface import LLMInterface +import logging +import requests +import re +import os + + +class GeminiProvider(LLMInterface): + def __init__(self, url: str = None, model: str = None, + default_input_max_characters: int = 1000, + default_generation_max_output_tokens: int = 1000, + default_generation_temperature: float = 0.1, api_key: str = None): + self.url = url or "https://generativelanguage.googleapis.com/v1beta" + self.api_key = api_key or os.getenv("GEMINI_API_KEY") + self.model = model + self.generation_model_id = None + + self.embedding_model = None + self.embedding_model_id = None + self.embedding_size = None + + self.default_input_max_characters = default_input_max_characters + self.default_generation_max_output_tokens = default_generation_max_output_tokens + self.default_generation_temperature = default_generation_temperature + self.logger = logging.getLogger(__name__) + + def set_generation_model(self, model_id: str): + if model_id: + self.model = model_id + + def set_embedding_model(self, model_id: str, embedding_size: int): + if model_id: + self.embedding_model = model_id + self.embedding_size = embedding_size + self.embedding_model_id = model_id + + def process_text(self, text: str): + if not text: + return "" + return str(text).strip() + + def _build_contents(self, prompt: str, chat_history: list) -> list: + """Convert chat_history + prompt into Gemini's contents format.""" + contents = [] + for entry in chat_history: + role = entry.get("role", "user") + # Gemini uses 'model' instead of 'assistant' + if role == "assistant": + role = "model" + contents.append({ + "role": role, + "parts": [{"text": entry.get("content", "")}] + }) + contents.append({ + "role": "user", + "parts": [{"text": prompt}] + }) + return contents + + def generate_text(self, prompt: str, chat_history: list = None, + max_output_tokens: int = None, temperature: float = None): + try: + chat_history = chat_history or [] + clean_prompt = self.process_text(prompt) + + contents = self._build_contents(clean_prompt, chat_history) + + payload = { + "contents": contents, + "generationConfig": { + "maxOutputTokens": int(max_output_tokens or self.default_generation_max_output_tokens), + "temperature": float(temperature or self.default_generation_temperature), + } + } + + url = ( + f"{self.url.rstrip('/')}/models/{self.model}" + f":generateContent?key={self.api_key}" + ) + headers = {"Content-Type": "application/json"} + + resp = requests.post(url, json=payload, headers=headers, timeout=6000) + if resp.status_code != 200: + self.logger.error("Gemini generate failed: %s %s", resp.status_code, resp.text) + return None + + data = resp.json() + + try: + generated_text = ( + data["candidates"][0]["content"]["parts"][0]["text"].strip() + ) + except (KeyError, IndexError, TypeError): + self.logger.error("Unexpected Gemini response structure: %s", data) + return None + + if not generated_text: + return None + + usage = data.get("usageMetadata", {}) + return { + "model": self.model, + "response": generated_text, + "tokens_generated": usage.get("candidatesTokenCount"), + "total_duration_ms": None, + "prompt_eval_tokens": usage.get("promptTokenCount"), + } + + except Exception as e: + self.logger.exception("Error in GeminiProvider.generate_text: %s", e) + return None + + def embed_text(self, text: str, document_type: str = None): + try: + if not self.embedding_model: + self.logger.error("Embedding model is not set before calling embed_text()") + return None + + clean_text = self.process_text(text) + print(f"[DEBUG] Cleaned text: '{clean_text[:20]}...'") + if not clean_text: + return [] + + # Map document_type to Gemini task type + task_type_map = { + "search_document": "RETRIEVAL_DOCUMENT", + "search_query": "RETRIEVAL_QUERY", + "classification": "CLASSIFICATION", + "clustering": "CLUSTERING", + } + task_type = task_type_map.get(document_type, "RETRIEVAL_DOCUMENT") + + payload = { + "model": f"models/{self.embedding_model}", + "content": {"parts": [{"text": clean_text}]}, + "output_dimensionality": 768, + "taskType": task_type, + } + + url = ( + f"{self.url.rstrip('/')}/models/{self.embedding_model}" + f":embedContent?key={self.api_key}" + ) + headers = {"Content-Type": "application/json"} + + resp = requests.post(url, json=payload, headers=headers, timeout=200) + if resp.status_code != 200: + print(f"[ERROR] Gemini embedding failed: {resp.status_code} {resp.text}") + return None + + data = resp.json() + + try: + embedding = data["embedding"]["values"] + print(f"[DEBUG] Embedding length: {len(embedding)}") + return embedding + except (KeyError, TypeError): + print("[WARNING] 'embedding' key not found in response JSON") + return None + + except Exception as e: + print(f"[EXCEPTION] Error in GeminiProvider.embed_text: {e}") + return None + + def construct_prompt(self, prompt: str, role: str): + return { + "role": role, + "content": self.process_text(prompt) + } + + def embed_text_batch(self, texts: list[str], batch_size: int = 32): + self.logger.info(f"Embedding {len(texts)} texts using batch_size={batch_size}") + + if not self.embedding_model: + self.logger.error("Embedding model not set") + return None + + all_embeddings = [] + + url = ( + f"{self.url.rstrip('/')}/models/{self.embedding_model}" + f":batchEmbedContents?key={self.api_key}" + ) + headers = {"Content-Type": "application/json"} + + for i in range(0, len(texts), batch_size): + time.sleep(5) + batch = texts[i:i + batch_size] + clean_batch = [self.process_text(t) for t in batch if t] + + print(f"[EMBED] Embedding {len(texts)} texts using batch_size={batch_size}") + + # Gemini batchEmbedContents takes a list of requests + requests_list = [ + { + "model": f"models/{self.embedding_model}", + "content": {"parts": [{"text": t}]}, + "taskType": "RETRIEVAL_DOCUMENT", + "output_dimensionality": 768, # ← add this + } + for t in clean_batch + ] + payload = {"requests": requests_list} + + resp = requests.post(url, json=payload, headers=headers, timeout=200) + if resp.status_code != 200: + self.logger.error("Gemini embedding failed: %s %s", resp.status_code, resp.text) + return None + + data = resp.json() + + try: + embeddings = [item["values"] for item in data["embeddings"]] + except (KeyError, TypeError): + self.logger.error("No embeddings returned from Gemini") + return None + + if not embeddings: + self.logger.error("No embeddings returned from Gemini") + return None + + self.logger.debug(f"Received {len(embeddings)} embeddings") + all_embeddings.extend(embeddings) + + self.logger.info(f"Total embeddings created: {len(all_embeddings)}") + return all_embeddings + + def clean_content(self, text: str) -> str: + text = re.sub(r'\[.*?\]\(.*?\)', '', text) + text = re.sub(r'\[[^\]]*\]', '', text) + text = re.sub(r'\n+', '\n', text).strip() + return text + + def web_search(self, query: str): + """ + Gemini supports Google Search grounding via the tools parameter. + Uses generateContent with the googleSearch tool enabled. + """ + try: + payload = { + "contents": [{"role": "user", "parts": [{"text": query}]}], + "tools": [{"google_search": {}}], + "generationConfig": { + "maxOutputTokens": int(self.default_generation_max_output_tokens), + "temperature": float(self.default_generation_temperature), + } + } + + url = ( + f"{self.url.rstrip('/')}/models/{self.model}" + f":generateContent?key={self.api_key}" + ) + headers = {"Content-Type": "application/json"} + + resp = requests.post(url, json=payload, headers=headers, timeout=6000) + + if not resp or resp.status_code != 200: + return { + "text": "No relevant external results found.", + "references": [] + } + + data = resp.json() + + combined_text = [] + references = set() + + try: + text_content = data["candidates"][0]["content"]["parts"][0]["text"] + combined_text.append(self.clean_content(text_content)) + except (KeyError, IndexError, TypeError): + pass + + # Extract grounding metadata URLs + try: + chunks = ( + data["candidates"][0] + .get("groundingMetadata", {}) + .get("groundingChunks", []) + ) + for chunk in chunks: + web = chunk.get("web", {}) + uri = web.get("uri", "") + if uri.startswith("http"): + references.add(uri) + except (KeyError, IndexError, TypeError): + pass + + # Also scan response text for bare URLs + for found_url in re.findall(r"https?://[^\s)]+", "\n".join(combined_text)): + references.add(found_url) + + return { + "text": "\n\n".join(combined_text[:3]), + "references": list(references) + } + + except Exception as e: + self.logger.error("Gemini web search failed: %s", e) + return { + "text": f"Gemini search error: {str(e)}", + "references": [] + } diff --git a/stores/llm/providers/GroqProvider.py b/stores/llm/providers/GroqProvider.py new file mode 100644 index 0000000000000000000000000000000000000000..b8730ad6b03df2a114d0d663c9e8ca7ace81f919 --- /dev/null +++ b/stores/llm/providers/GroqProvider.py @@ -0,0 +1,133 @@ +from stores.llm.LLMInterface import LLMInterface +import logging +import requests +import re +import os + + +class GroqProvider(LLMInterface): + def __init__(self, url: str = None, model: str = None, + default_input_max_characters: int = 1000, + default_generation_max_output_tokens: int = 1000, + default_generation_temperature: float = 0.1, api_key: str = None): + self.url = url or "https://api.groq.com/openai/v1" + self.api_key = api_key or os.getenv("GROQ_API_KEY") + self.model = model + self.generation_model_id = None + + self.embedding_model = None + self.embedding_model_id = None + self.embedding_size = None + + self.default_input_max_characters = default_input_max_characters + self.default_generation_max_output_tokens = default_generation_max_output_tokens + self.default_generation_temperature = default_generation_temperature + self.logger = logging.getLogger(__name__) + + def set_generation_model(self, model_id: str): + if model_id: + self.model = model_id + + def set_embedding_model(self, model_id: str, embedding_size: int): + if model_id: + self.embedding_model = model_id + self.embedding_size = embedding_size + self.embedding_model_id = model_id + + def process_text(self, text: str): + if not text: + return "" + return str(text).strip() + + def generate_text(self, prompt: str, chat_history: list = None, + max_output_tokens: int = None, temperature: float = None): + try: + chat_history = chat_history or [] + clean_prompt = self.process_text(prompt) + + messages = [] + for entry in chat_history: + messages.append({ + "role": entry.get("role", "user"), + "content": entry.get("content", "") + }) + messages.append({"role": "user", "content": clean_prompt}) + + payload = { + "model": self.model, + "messages": messages, + "max_tokens": int(max_output_tokens or self.default_generation_max_output_tokens), + "temperature": float(temperature or self.default_generation_temperature), + } + + url = self.url.rstrip("/") + "/chat/completions" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + resp = requests.post(url, json=payload, headers=headers, timeout=6000) + if resp.status_code != 200: + self.logger.error("Groq generate failed: %s %s", resp.status_code, resp.text) + return None + + data = resp.json() + + try: + generated_text = data["choices"][0]["message"]["content"].strip() + except (KeyError, IndexError, TypeError): + self.logger.error("Unexpected Groq response structure: %s", data) + return None + + if not generated_text: + return None + + usage = data.get("usage", {}) + # Groq exposes x_groq.usage.total_time in seconds + total_time_ms = None + try: + total_time_ms = round(data["x_groq"]["usage"]["total_time"] * 1000, 2) + except (KeyError, TypeError): + pass + + return { + "model": data.get("model"), + "response": generated_text, + "tokens_generated": usage.get("completion_tokens"), + "total_duration_ms": total_time_ms, + "prompt_eval_tokens": usage.get("prompt_tokens"), + } + + except Exception as e: + self.logger.exception("Error in GroqProvider.generate_text: %s", e) + return None + + def embed_text(self, text: str, document_type: str = None): + """Groq does not support embeddings — returns None.""" + self.logger.warning("GroqProvider does not support embeddings.") + return None + + def construct_prompt(self, prompt: str, role: str): + return { + "role": role, + "content": self.process_text(prompt) + } + + def embed_text_batch(self, texts: list[str], batch_size: int = 32): + """Groq does not support embeddings — returns None.""" + self.logger.warning("GroqProvider does not support embeddings.") + return None + + def clean_content(self, text: str) -> str: + text = re.sub(r'\[.*?\]\(.*?\)', '', text) + text = re.sub(r'\[[^\]]*\]', '', text) + text = re.sub(r'\n+', '\n', text).strip() + return text + + def web_search(self, query: str): + """Groq has no native web search — returns a not-supported notice.""" + self.logger.warning("GroqProvider.web_search is not natively supported.") + return { + "text": "Web search is not natively supported by the Groq API.", + "references": [] + } diff --git a/stores/llm/providers/HuggingFaceProvider.py b/stores/llm/providers/HuggingFaceProvider.py new file mode 100644 index 0000000000000000000000000000000000000000..9b2d9c66e49588f99a308f63420bc258245a9850 --- /dev/null +++ b/stores/llm/providers/HuggingFaceProvider.py @@ -0,0 +1,214 @@ +from stores.llm.LLMInterface import LLMInterface +import logging +import requests +import re +import os + + +class HuggingFaceProvider(LLMInterface): + def __init__(self, url: str = None, model: str = None, + default_input_max_characters: int = 1000, + default_generation_max_output_tokens: int = 1000, + default_generation_temperature: float = 0.1, api_key: str = None): + # Supports both Inference API (serverless) and Inference Endpoints (dedicated) + self.url = url or "https://router.huggingface.co" + self.api_key = api_key or os.getenv("HF_API_KEY") + self.model = model + self.generation_model_id = None + + self.embedding_model = None + self.embedding_model_id = None + self.embedding_size = None + + self.default_input_max_characters = default_input_max_characters + self.default_generation_max_output_tokens = default_generation_max_output_tokens + self.default_generation_temperature = default_generation_temperature + self.logger = logging.getLogger(__name__) + + def set_generation_model(self, model_id: str): + if model_id: + self.model = model_id + + def set_embedding_model(self, model_id: str, embedding_size: int): + if model_id: + self.embedding_model = model_id + self.embedding_size = embedding_size + self.embedding_model_id = model_id + + def process_text(self, text: str): + if not text: + return "" + return str(text).strip() + + def generate_text(self, prompt: str, chat_history: list = None, + max_output_tokens: int = None, temperature: float = None): + try: + chat_history = chat_history or [] + clean_prompt = self.process_text(prompt) + + messages = [] + for entry in chat_history: + messages.append({ + "role": entry.get("role", "user"), + "content": entry.get("content", "") + }) + messages.append({"role": "user", "content": clean_prompt}) + + payload = { + "model": self.model, + "messages": messages, + "max_tokens": int(max_output_tokens or self.default_generation_max_output_tokens), + "temperature": float(temperature or self.default_generation_temperature), + } + + # HF Inference API (serverless): /v1/chat/completions (OpenAI-compatible) + url = self.url.rstrip("/") + "/v1/chat/completions" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + resp = requests.post(url, json=payload, headers=headers, timeout=6000) + if resp.status_code != 200: + self.logger.error("HuggingFace generate failed: %s %s", resp.status_code, resp.text) + return None + + data = resp.json() + + try: + generated_text = data["choices"][0]["message"]["content"].strip() + except (KeyError, IndexError, TypeError): + self.logger.error("Unexpected HuggingFace response structure: %s", data) + return None + + if not generated_text: + return None + + usage = data.get("usage", {}) + return { + "model": data.get("model"), + "response": generated_text, + "tokens_generated": usage.get("completion_tokens"), + "total_duration_ms": None, + "prompt_eval_tokens": usage.get("prompt_tokens"), + } + + except Exception as e: + self.logger.exception("Error in HuggingFaceProvider.generate_text: %s", e) + return None + + def embed_text(self, text: str, document_type: str = None): + try: + if not self.embedding_model: + self.logger.error("Embedding model is not set before calling embed_text()") + return None + + clean_text = self.process_text(text) + print(f"[DEBUG] Cleaned text: '{clean_text[:20]}...'") + if not clean_text: + return [] + + payload = {"inputs": clean_text} + + # Feature-extraction endpoint per model + url = f"https://router.huggingface.co/hf-inference/models/{self.embedding_model}/pipeline/feature-extraction" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + resp = requests.post(url, json=payload, headers=headers, timeout=200) + if resp.status_code != 200: + print(f"[ERROR] HuggingFace embedding failed: {resp.status_code} {resp.text}") + return None + + data = resp.json() + + # HF returns a nested list: [[vector]] for single input + embedding = None + if isinstance(data, list): + if len(data) > 0 and isinstance(data[0], list): + embedding = data[0] # [[float, ...]] -> [float, ...] + elif len(data) > 0 and isinstance(data[0], float): + embedding = data # [float, ...] already flat + elif isinstance(data, dict) and "embedding" in data: + embedding = data["embedding"] + + if embedding is not None: + print(f"[DEBUG] Embedding length: {len(embedding)}") + return embedding + + print("[WARNING] 'embedding' key not found in response JSON") + return None + + except Exception as e: + print(f"[EXCEPTION] Error in HuggingFaceProvider.embed_text: {e}") + return None + + def construct_prompt(self, prompt: str, role: str): + return { + "role": role, + "content": self.process_text(prompt) + } + + def embed_text_batch(self, texts: list[str], batch_size: int = 32): + self.logger.info(f"Embedding {len(texts)} texts using batch_size={batch_size}") + + if not self.embedding_model: + self.logger.error("Embedding model not set") + return None + + all_embeddings = [] + + url = f"https://router.huggingface.co/hf-inference/models/{self.embedding_model}/pipeline/feature-extraction" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + for i in range(0, len(texts), batch_size): + batch = texts[i:i + batch_size] + clean_batch = [self.process_text(t) for t in batch if t] + + print(f"[EMBED] Embedding {len(texts)} texts using batch_size={batch_size}") + + payload = {"inputs": clean_batch} + + resp = requests.post(url, json=payload, headers=headers, timeout=200) + if resp.status_code != 200: + self.logger.error("HuggingFace embedding failed: %s %s", resp.status_code, resp.text) + return None + + data = resp.json() + + # Batch response: [[vec1], [vec2], ...] or [[f,f,...], [f,f,...]] + embeddings = None + if isinstance(data, list) and len(data) > 0: + if isinstance(data[0], list): + embeddings = data + elif isinstance(data[0], float): + embeddings = [data] # single vector returned flat + + if not embeddings: + self.logger.error("No embeddings returned from HuggingFace") + return None + + self.logger.debug(f"Received {len(embeddings)} embeddings") + all_embeddings.extend(embeddings) + + self.logger.info(f"Total embeddings created: {len(all_embeddings)}") + return all_embeddings + + def clean_content(self, text: str) -> str: + text = re.sub(r'\[.*?\]\(.*?\)', '', text) + text = re.sub(r'\[[^\]]*\]', '', text) + text = re.sub(r'\n+', '\n', text).strip() + return text + + def web_search(self, query: str): + """HuggingFace Inference API has no native web search — returns a not-supported notice.""" + self.logger.warning("HuggingFaceProvider.web_search is not natively supported.") + return { + "text": "Web search is not natively supported by the HuggingFace Inference API.", + "references": [] + } diff --git a/stores/llm/providers/MistralProvider.py b/stores/llm/providers/MistralProvider.py new file mode 100644 index 0000000000000000000000000000000000000000..e113ebbf85fb384cebb80f866da7159ae2566367 --- /dev/null +++ b/stores/llm/providers/MistralProvider.py @@ -0,0 +1,208 @@ +import time + +from stores.llm.LLMInterface import LLMInterface +import logging +import requests +import re +import os + + +class MistralProvider(LLMInterface): + def __init__(self, url: str = None, model: str = None, + default_input_max_characters: int = 1000, + default_generation_max_output_tokens: int = 1000, + default_generation_temperature: float = 0.1, api_key: str = None): + self.url = url or "https://api.mistral.ai/v1" + self.api_key = api_key or os.getenv("MISTRAL_API_KEY") + self.model = model + self.generation_model_id = None + + self.embedding_model = None + self.embedding_model_id = None + self.embedding_size = None + + self.default_input_max_characters = default_input_max_characters + self.default_generation_max_output_tokens = default_generation_max_output_tokens + self.default_generation_temperature = default_generation_temperature + self.logger = logging.getLogger(__name__) + + def set_generation_model(self, model_id: str): + if model_id: + self.model = model_id + + def set_embedding_model(self, model_id: str, embedding_size: int): + if model_id: + self.embedding_model = model_id + self.embedding_size = embedding_size + self.embedding_model_id = model_id + + def process_text(self, text: str): + if not text: + return "" + return str(text).strip() + + def generate_text(self, prompt: str, chat_history: list = None, + max_output_tokens: int = None, temperature: float = None): + try: + chat_history = chat_history or [] + clean_prompt = self.process_text(prompt) + + messages = [] + for entry in chat_history: + messages.append({ + "role": entry.get("role", "user"), + "content": entry.get("content", "") + }) + messages.append({"role": "user", "content": clean_prompt}) + + payload = { + "model": self.model, + "messages": messages, + "max_tokens": int(max_output_tokens or self.default_generation_max_output_tokens), + "temperature": float(temperature or self.default_generation_temperature), + } + + url = self.url.rstrip("/") + "/chat/completions" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + resp = requests.post(url, json=payload, headers=headers, timeout=6000) + if resp.status_code != 200: + self.logger.error("Mistral generate failed: %s %s", resp.status_code, resp.text) + return None + + data = resp.json() + + try: + generated_text = data["choices"][0]["message"]["content"].strip() + except (KeyError, IndexError, TypeError): + self.logger.error("Unexpected Mistral response structure: %s", data) + return None + + if not generated_text: + return None + + usage = data.get("usage", {}) + return { + "model": data.get("model"), + "response": generated_text, + "tokens_generated": usage.get("completion_tokens"), + "total_duration_ms": None, + "prompt_eval_tokens": usage.get("prompt_tokens"), + } + + except Exception as e: + self.logger.exception("Error in MistralProvider.generate_text: %s", e) + return None + + def embed_text(self, text: str, document_type: str = None): + try: + if not self.embedding_model: + self.logger.error("Embedding model is not set before calling embed_text()") + return None + + clean_text = self.process_text(text) + print(f"[DEBUG] Cleaned text: '{clean_text[:20]}...'") + if not clean_text: + return [] + + payload = { + "model": self.embedding_model, + "input": [clean_text], + } + + url = self.url.rstrip("/") + "/embeddings" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + resp = requests.post(url, json=payload, headers=headers, timeout=200) + if resp.status_code != 200: + print(f"[ERROR] Mistral embedding failed: {resp.status_code} {resp.text}") + return None + + data = resp.json() + + try: + embedding = data["data"][0]["embedding"] + print(f"[DEBUG] Embedding length: {len(embedding)}") + return embedding + except (KeyError, IndexError, TypeError): + print("[WARNING] 'embedding' key not found in response JSON") + return None + + except Exception as e: + print(f"[EXCEPTION] Error in MistralProvider.embed_text: {e}") + return None + + def construct_prompt(self, prompt: str, role: str): + return { + "role": role, + "content": self.process_text(prompt) + } + + def embed_text_batch(self, texts: list[str], batch_size: int = 32): + self.logger.info(f"Embedding {len(texts)} texts using batch_size={batch_size}") + + if not self.embedding_model: + self.logger.error("Embedding model not set") + return None + + all_embeddings = [] + url = self.url.rstrip("/") + "/embeddings" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + for i in range(0, len(texts), batch_size): + time.sleep(5) + batch = texts[i:i + batch_size] + clean_batch = [self.process_text(t) for t in batch if t] + + print(f"[EMBED] Embedding {len(texts)} texts using batch_size={batch_size}") + + payload = { + "model": self.embedding_model, + "input": clean_batch, + } + + resp = requests.post(url, json=payload, headers=headers, timeout=200) + if resp.status_code != 200: + self.logger.error("Mistral embedding failed: %s %s", resp.status_code, resp.text) + return None + + data = resp.json() + + try: + embeddings = [item["embedding"] for item in data["data"]] + except (KeyError, TypeError): + self.logger.error("No embeddings returned from Mistral") + return None + + if not embeddings: + self.logger.error("No embeddings returned from Mistral") + return None + + self.logger.debug(f"Received {len(embeddings)} embeddings") + all_embeddings.extend(embeddings) + + self.logger.info(f"Total embeddings created: {len(all_embeddings)}") + return all_embeddings + + def clean_content(self, text: str) -> str: + text = re.sub(r'\[.*?\]\(.*?\)', '', text) + text = re.sub(r'\[[^\]]*\]', '', text) + text = re.sub(r'\n+', '\n', text).strip() + return text + + def web_search(self, query: str): + """Mistral has no native web search — returns a not-supported notice.""" + self.logger.warning("MistralProvider.web_search is not natively supported.") + return { + "text": "Web search is not natively supported by the Mistral API.", + "references": [] + } diff --git a/stores/llm/providers/OllamaProvider.py b/stores/llm/providers/OllamaProvider.py new file mode 100644 index 0000000000000000000000000000000000000000..5635bb60e256d6bd1a104b7c9d6576a09a4286fe --- /dev/null +++ b/stores/llm/providers/OllamaProvider.py @@ -0,0 +1,292 @@ +from stores.llm.LLMInterface import LLMInterface +import logging +import requests +import re +import ollama +import os +class OllamaProvider(LLMInterface): + def __init__(self, url: str=None, model: str=None, + default_input_max_characters: int=1000, + default_generation_max_output_tokens: int=1000, + default_generation_temperature: float=0.1, api_key: str=None): + self.url = url or "http://localhost:11434" + self.api_key = api_key or os.getenv("OLLAMA_API_KEY") + self.model = model + self.generation_model_id = None + + self.embedding_model = None + self.embedding_model_id = None + self.embedding_size = None + + self.default_input_max_characters = default_input_max_characters + self.default_generation_max_output_tokens = default_generation_max_output_tokens + self.default_generation_temperature = default_generation_temperature + self.logger = logging.getLogger(__name__) + + def set_generation_model(self, model_id: str): + if model_id: + self.model = model_id + + def set_embedding_model(self, model_id: str, embedding_size: int): + if model_id: + self.embedding_model = model_id + self.embedding_size = embedding_size + self.embedding_model_id = model_id + + def process_text(self, text: str): + if not text: + return "" + return str(text).strip() + + def generate_text(self, prompt: str, chat_history: list = None, + max_output_tokens: int = None, temperature: float = None): + + + try: + chat_history = chat_history or [] # safe handling + clean_prompt = self.process_text(prompt) + + # Build payload with correct Ollama keys + payload = { + "model": self.model, + "prompt": clean_prompt, + "stream": False, + "num_predict": int(max_output_tokens or self.default_generation_max_output_tokens), + "temperature": float(temperature or self.default_generation_temperature), + } + + url = self.url.rstrip("/") + "/api/generate" + headers = {} + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + resp = requests.post(url, json=payload, headers=headers, timeout=6000) + if resp.status_code != 200: + self.logger.error("Ollama generate failed: %s %s", resp.status_code, resp.text) + return None + + data = resp.json() + + # Extract final generated text correctly + generated_text = data.get("response", "").strip() + + # If nothing generated, treat as failure + if not generated_text: + return None + + # Return clean JSON instead of raw text + return { + "model": data.get("model"), + "response": generated_text, + "tokens_generated": data.get("eval_count"), + "total_duration_ms": round(data.get("total_duration", 0) / 1e6, 2), + "prompt_eval_tokens": data.get("prompt_eval_count"), + } + + except Exception as e: + self.logger.exception("Error in OllamaProvider.generate_text: %s", e) + return None + + def embed_text(self, text: str, document_type: str = None): + """Return an embedding vector from Ollama.""" + try: + if not self.embedding_model: + self.logger.error("Embedding model is not set before calling embed_text()") + return None + + clean_text = self.process_text(text) + print(f"[DEBUG] Cleaned text: '{clean_text[:20]}...'") + if not clean_text: + return [] + + payload = { + "model": self.embedding_model, + "input": clean_text + } + + url = self.url.rstrip("/") + "/api/embed" + headers = {} + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + + resp = requests.post(url, json=payload, headers=headers, timeout=400) + if resp.status_code != 200: + print(f"[ERROR] Ollama embedding failed: {resp.status_code} {resp.text}") + return None + + data = resp.json() + + # Expected format: { "embedding": [...] } + if "embedding" in data: + print(f"[DEBUG] Embedding length: {len(data['embedding'])}") + return data["embedding"] + elif "embeddings" in data: + return data["embeddings"][0] + + print("[WARNING] 'embedding' key not found in response JSON") + return None + + except Exception as e: + print(f"[EXCEPTION] Error in OllamaProvider.embed_text: {e}") + return None + + def construct_prompt(self, prompt: str, role: str): + return { + "role": role, + "content": self.process_text(prompt) + } + + # def embed_text_batch(self, texts: list[str], batch_size: int = 32): + + # self.logger.info(f"Embedding {len(texts)} texts using batch_size={batch_size}") + + # if not self.embedding_model: + # self.logger.error("Embedding model not set") + # return None + + # all_embeddings = [] + + # url = self.url.rstrip("/") + "/api/embed" + + # for i in range(0, len(texts), batch_size): + # batch = texts[i:i + batch_size] + + # clean_batch = [self.process_text(t) for t in batch if t] + + # print(f"[EMBED] Embedding {len(texts)} texts using batch_size={batch_size} Progress = {i+batch_size}") + + # payload = { + # "model": self.embedding_model, + # "input": clean_batch + # } + + # resp = requests.post(url, json=payload, timeout=400) + + # if resp.status_code != 200: + # self.logger.error("Ollama embedding failed: %s %s", resp.status_code, resp.text) + # return None + + # data = resp.json() + + # embeddings = data.get("embeddings") + + # if not embeddings: + # self.logger.error("No embeddings returned from Ollama") + # return None + + # self.logger.debug(f"Received {len(embeddings)} embeddings") + + # all_embeddings.extend(embeddings) + + # self.logger.info(f"Total embeddings created: {len(all_embeddings)}") + + # return all_embeddings + + def embed_text_batch(self, texts: list[str], batch_size: int = 64): + """ + Batch embedding for a list of texts, compatible with both /api/embed (new) and /api/embeddings (legacy). + Logs progress and returns a list of embedding vectors. + """ + all_embeddings = [] + + endpoints = ["/api/embed", "/api/embeddings"] + headers = {"Content-Type": "application/json"} + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + + total_texts = len(texts) + self.logger.info(f"Starting batch embedding of {total_texts} texts with batch_size={batch_size}") + + for ep in endpoints: + try: + for i in range(0, total_texts, batch_size): + batch = texts[i:i + batch_size] + clean_batch = [self.process_text(t) for t in batch if t] + + payload = {"model": self.embedding_model} + + if ep == "/api/embed": + payload["input"] = clean_batch + resp = requests.post(self.url.rstrip("/") + ep, json=payload, headers=headers, timeout=400) + if resp.status_code != 200: + self.logger.warning( + "Batch embedding failed at %s: %s %s", ep, resp.status_code, resp.text + ) + continue + + data = resp.json() + embeddings = data.get("embeddings") or ([data.get("embedding")] if "embedding" in data else []) + all_embeddings.extend(embeddings) + + else: + # Legacy endpoint: send individually + for j, t in enumerate(clean_batch): + payload_legacy = {"model": self.embedding_model, "prompt": t} + resp = requests.post(self.url.rstrip("/") + ep, json=payload_legacy, headers=headers, timeout=400) + if resp.status_code != 200: + self.logger.warning( + "Legacy embedding failed at %s: %s %s", ep, resp.status_code, resp.text + ) + continue + + data = resp.json() + if "embedding" in data: + all_embeddings.append(data["embedding"]) + self.logger.info(f"Embedded {i+j+1}/{total_texts} texts using legacy endpoint") + + # Log batch progress + self.logger.info(f"Embedded {min(i+batch_size, total_texts)}/{total_texts} texts using {ep}") + + if all_embeddings: + self.logger.info(f"Finished embedding {len(all_embeddings)}/{total_texts} texts successfully") + break # stop after successful endpoint + + except Exception as e: + self.logger.exception("Batch embedding error at %s: %s", ep, e) + + return all_embeddings + + def clean_content(self, text: str) -> str: + text = re.sub(r'\[.*?\]\(.*?\)', '', text) + text = re.sub(r'\[[^\]]*\]', '', text) + text = re.sub(r'\n+', '\n', text).strip() + return text + + def web_search(self, query: str): + """Use Ollama client to perform web search and return cleaned text + references.""" + try: + # Use your old working Ollama client + OLLAMA_API_KEY = os.getenv("OLLAMA_API_KEY") + ollama_client = ollama.Client(headers={'Authorization': 'Bearer ' + OLLAMA_API_KEY}) + response = ollama_client.web_search(query) + + if not response or "results" not in response or len(response["results"]) == 0: + return { + "text": "No relevant external results found.", + "references": [] + } + + combined_text = [] + references = set() + + for item in response["results"]: + text = self.clean_content(item.content) + combined_text.append(text) + + urls = re.findall(r"https?://[^\s)]+", item.content) + for url in urls: + references.add(url) + + if hasattr(item, "url") and item.url: + references.add(item.url) + + return { + "text": "\n\n".join(combined_text[:3]), + "references": list(references) + } + + except Exception as e: + self.logger.error("Ollama web search failed: %s", e) + return { + "text": f"Ollama search error: {str(e)}", + "references": [] + } diff --git a/stores/llm/providers/OpenAIProvider.py b/stores/llm/providers/OpenAIProvider.py new file mode 100644 index 0000000000000000000000000000000000000000..85d1d37729120fea4d1071c1ecd40521134c4799 --- /dev/null +++ b/stores/llm/providers/OpenAIProvider.py @@ -0,0 +1,102 @@ +from ..LLMInterface import LLMInterface +from ..LLMEnums import OpenAIEnums +from openai import OpenAI +import logging + +class OpenAIProvider(LLMInterface): + def __init__(self, api_key: str, api_url: str = None, + default_input_max_characters: int = 1000, + default_generation_max_output_tokens: int = 1000, + default_generation_temperature: float = 0.1): + self.api_key = api_key + self.api_url = api_url + self.default_input_max_characters = default_input_max_characters + self.default_generation_max_output_tokens = default_generation_max_output_tokens + self.default_generation_temperature = default_generation_temperature + + self.generation_model_id = None + self.embedding_model_id = None + self.embedding_size = None + + self.client = OpenAI(api_key=self.api_key, base_url=self.api_url) + self.logger = logging.getLogger(__name__) + + def set_generation_model(self, model_id: str): + self.generation_model_id = model_id + + def set_embedding_model(self, model_id: str, embedding_size: int): + self.embedding_model_id = model_id + self.embedding_size = embedding_size + + def process_text(self, text: str): + return text[:self.default_input_max_characters].strip() + + def generate_text(self, prompt: str, chat_history: list = None, + max_output_tokens: int = None, temperature: float = None): + if not self.client: + self.logger.error("OpenAI client was not initialized") + return None + + if not self.generation_model_id: + self.logger.error("OpenAI generation model not set") + return None + + max_output_tokens = max_output_tokens or self.default_generation_max_output_tokens + temperature = temperature or self.default_generation_temperature + + messages = chat_history[:] if chat_history else [] + messages.append(self.construct_prompt(prompt, OpenAIEnums.USER.value)) + + try: + response = self.client.chat.completions.create( + model=self.generation_model_id, + messages=messages, + max_completion_tokens=max_output_tokens, + temperature=temperature + ) + + if (not response or not response.choices + or not response.choices[0].message + or not response.choices[0].message.content): + self.logger.error("Invalid OpenAI response format") + return None + + return response.choices[0].message.content + + except Exception as e: + self.logger.exception("Error while generating text with OpenAI: %s", e) + return None + + def embed_text_batch(self, texts: list[str], batch_size: int = 32): + pass + + def embed_text(self, text: str, document_type: str = None): + if not self.client: + self.logger.error("OpenAI client was not initialized") + return None + + if not self.embedding_model_id: + self.logger.error("OpenAI embedding model not set") + return None + + try: + response = self.client.embeddings.create( + model=self.embedding_model_id, + input=text + ) + + if not response or not response.data or not response.data[0].embedding: + self.logger.error("Invalid OpenAI embedding response") + return None + + return response.data[0].embedding + + except Exception as e: + self.logger.exception("Error while embedding text with OpenAI: %s", e) + return None + + def construct_prompt(self, prompt: str, role: str): + return { + "role": role, + "content": self.process_text(prompt) + } diff --git a/stores/llm/providers/OpenRouterProvider.py b/stores/llm/providers/OpenRouterProvider.py new file mode 100644 index 0000000000000000000000000000000000000000..d108f37dee4c2b4864a27b0f67eba726091ae0ec --- /dev/null +++ b/stores/llm/providers/OpenRouterProvider.py @@ -0,0 +1,179 @@ +from stores.llm.LLMInterface import LLMInterface +import logging +import requests +import re +import os + + +class OpenRouterProvider(LLMInterface): + def __init__(self, url: str = None, model: str = None, + default_input_max_characters: int = 1000, + default_generation_max_output_tokens: int = 1000, + default_generation_temperature: float = 0.1, api_key: str = None): + self.url = url or "https://openrouter.ai/api/v1" + self.api_key = api_key or os.getenv("OPENROUTER_API_KEY") + self.model = model + self.generation_model_id = None + + self.embedding_model = None + self.embedding_model_id = None + self.embedding_size = None + + self.default_input_max_characters = default_input_max_characters + self.default_generation_max_output_tokens = default_generation_max_output_tokens + self.default_generation_temperature = default_generation_temperature + self.logger = logging.getLogger(__name__) + + def set_generation_model(self, model_id: str): + if model_id: + self.model = model_id + + def set_embedding_model(self, model_id: str, embedding_size: int): + if model_id: + self.embedding_model = model_id + self.embedding_size = embedding_size + self.embedding_model_id = model_id + + def process_text(self, text: str): + if not text: + return "" + return str(text).strip() + + def generate_text(self, prompt: str, chat_history: list = None, + max_output_tokens: int = None, temperature: float = None): + try: + chat_history = chat_history or [] + clean_prompt = self.process_text(prompt) + + messages = [] + for entry in chat_history: + messages.append({ + "role": entry.get("role", "user"), + "content": entry.get("content", "") + }) + messages.append({"role": "user", "content": clean_prompt}) + + payload = { + "model": self.model, + "messages": messages, + "max_tokens": int(max_output_tokens or self.default_generation_max_output_tokens), + "temperature": float(temperature or self.default_generation_temperature), + } + + url = self.url.rstrip("/") + "/chat/completions" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + # Recommended by OpenRouter for usage tracking + "HTTP-Referer": os.getenv("OPENROUTER_SITE_URL", "http://localhost"), + "X-Title": os.getenv("OPENROUTER_APP_NAME", "LLMApp"), + } + + resp = requests.post(url, json=payload, headers=headers, timeout=6000) + if resp.status_code != 200: + self.logger.error("OpenRouter generate failed: %s %s", resp.status_code, resp.text) + return None + + data = resp.json() + + try: + generated_text = data["choices"][0]["message"]["content"].strip() + except (KeyError, IndexError, TypeError): + self.logger.error("Unexpected OpenRouter response structure: %s", data) + return None + + if not generated_text: + return None + + usage = data.get("usage", {}) + return { + "model": data.get("model"), + "response": generated_text, + "tokens_generated": usage.get("completion_tokens"), + "total_duration_ms": None, + "prompt_eval_tokens": usage.get("prompt_tokens"), + } + + except Exception as e: + self.logger.exception("Error in OpenRouterProvider.generate_text: %s", e) + return None + + def embed_text(self, text: str, document_type: str = None): + """OpenRouter does not support embeddings natively — returns None.""" + self.logger.warning("OpenRouterProvider does not support embeddings.") + return None + + def construct_prompt(self, prompt: str, role: str): + return { + "role": role, + "content": self.process_text(prompt) + } + + def embed_text_batch(self, texts: list[str], batch_size: int = 32): + """OpenRouter does not support embeddings natively — returns None.""" + self.logger.warning("OpenRouterProvider does not support embeddings.") + return None + + def clean_content(self, text: str) -> str: + text = re.sub(r'\[.*?\]\(.*?\)', '', text) + text = re.sub(r'\[[^\]]*\]', '', text) + text = re.sub(r'\n+', '\n', text).strip() + return text + + def web_search(self, query: str): + """ + OpenRouter supports online models (e.g. perplexity/sonar-online) that have + built-in web search. Route the query through one of those models if available, + otherwise fall back to a not-supported notice. + """ + try: + online_model = os.getenv("OPENROUTER_SEARCH_MODEL", "perplexity/sonar-online") + + payload = { + "model": online_model, + "messages": [{"role": "user", "content": query}], + "max_tokens": int(self.default_generation_max_output_tokens), + "temperature": float(self.default_generation_temperature), + } + + url = self.url.rstrip("/") + "/chat/completions" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + "HTTP-Referer": os.getenv("OPENROUTER_SITE_URL", "http://localhost"), + "X-Title": os.getenv("OPENROUTER_APP_NAME", "LLMApp"), + } + + resp = requests.post(url, json=payload, headers=headers, timeout=6000) + if not resp or resp.status_code != 200: + return { + "text": "No relevant external results found.", + "references": [] + } + + data = resp.json() + + combined_text = [] + references = set() + + try: + text_content = data["choices"][0]["message"]["content"] + combined_text.append(self.clean_content(text_content)) + except (KeyError, IndexError, TypeError): + pass + + # Extract any URLs from the response text + for found_url in re.findall(r"https?://[^\s)]+", "\n".join(combined_text)): + references.add(found_url) + + return { + "text": "\n\n".join(combined_text[:3]), + "references": list(references) + } + + except Exception as e: + self.logger.error("OpenRouter web search failed: %s", e) + return { + "text": f"OpenRouter search error: {str(e)}", + "references": [] + } diff --git a/stores/llm/providers/__init__.py b/stores/llm/providers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/stores/vector_store/Qdrant.py b/stores/vector_store/Qdrant.py new file mode 100644 index 0000000000000000000000000000000000000000..b90c12c70a055e5a28c9b0cf391c15b4de3890b7 --- /dev/null +++ b/stores/vector_store/Qdrant.py @@ -0,0 +1,441 @@ +from typing import List, Dict, Any, Literal, Optional, TypedDict +from qdrant_client import QdrantClient +from qdrant_client.models import (VectorParams,Distance,PointStruct,Filter, + FieldCondition,MatchValue,PointIdsList,MatchText,MatchAny) + +import uuid + +MatchType = Literal["eq", "text", "in"] + +class MetaFilter(TypedDict): + field: str # metadata key + op: MatchType # eq | text | in + value: Any + clause: Literal["must", "should", "must_not"] + +# filters = [ +# {"field": "source", "op": "eq", "value": "file.pdf", "clause": "must"}, +# {"field": "course", "op": "in", "value": ["math", "cs"], "clause": "should"}, +# {"field": "bookmark_path", "op": "text", "value": "chapter1", "clause": "must"}, +# ] + +class QdrantStore: + def __init__(self, client: QdrantClient, collection_name: str, vector_size: int): + self.client = client + self.collection_name = collection_name + self.vector_size = vector_size + self.init_collection() + + def init_collection(self): + existing = [c.name for c in self.client.get_collections().collections] + if self.collection_name in existing: + print(f"[INFO] Collection '{self.collection_name}' exists. ") + else: + self.client.create_collection( + collection_name=self.collection_name, + vectors_config=VectorParams(size=self.vector_size, distance=Distance.COSINE) + ) + print(f"[INFO] Created collection '{self.collection_name}' with vector size {self.vector_size}") + + def upsert_embeddings( + self, + client: QdrantClient, + collection: str, + embeddings: List[List[float]], + payloads: List[Dict[str, Any]], + batch_size: int = 64, +): + total = len(embeddings) + + for i in range(0, total, batch_size): + batch_embs = embeddings[i:i + batch_size] + batch_payloads = payloads[i:i + batch_size] + + points = [ + PointStruct( + id=str(uuid.uuid4()), + vector=emb, + payload=payload + ) + for emb, payload in zip(batch_embs, batch_payloads) + if emb is not None + ] + + if points: + self.client.upsert( + collection_name=self.collection_name, + points=points + ) + print(f"[INFO] Inserted batch {i//batch_size + 1} ({len(points)} vectors)") +# def upsert_embeddings( +# self, +# client: QdrantClient, +# collection: str, +# embeddings: List[List[float]], +# payloads: List[Dict[str, Any]], +# batch_size: int = 128, +# ): +# total = len(embeddings) + +# for i in range(0, total, batch_size): + +# batch_embs = embeddings[i:i + batch_size] +# batch_payloads = payloads[i:i + batch_size] + +# points = [] + +# for emb, payload in zip(batch_embs, batch_payloads): +# if emb is None: +# continue + +# points.append( +# PointStruct( +# id=str(uuid.uuid4()), +# vector=emb, +# payload=payload +# ) +# ) + +# if points: +# client.upsert( +# collection_name=collection, +# points=points +# ) + +# print( +# f"[INFO] Inserted batch {i//batch_size + 1} " +# f"({len(points)} vectors)" +# ) + + def delete_by_id(self,client: QdrantClient, collection: str, point_id: str): + try: + point_id_int = int(point_id) + client.delete( + collection_name=collection, + points_selector=PointIdsList(points=[point_id_int]) + ) + print(f"[INFO] Deleted point ID: {point_id}") + except Exception as exc: + print(f"[ERROR] Failed to delete point {point_id}: {exc}") + + def build_qdrant_filter(self,filters: list[MetaFilter] | None) -> Filter | None: + if not filters: + return None + must, should, must_not = [], [], [] + for f in filters: + key = f"metadata.{f['field']}" + op = f["op"] + value = f["value"] + + if op == "eq": + cond = FieldCondition(key=key, match=MatchValue(value=value)) + + elif op == "text": + cond = FieldCondition(key=key, match=MatchText(text=value)) + + elif op == "in": + cond = FieldCondition(key=key, match=MatchAny(any=value)) + + else: + raise ValueError(f"Unsupported op: {op}") + + if f["clause"] == "must": + must.append(cond) + elif f["clause"] == "should": + should.append(cond) + elif f["clause"] == "must_not": + must_not.append(cond) + + return Filter( + must=must or None, + should=should or None, + must_not=must_not or None, + ) + + def query_qdrant( + self, + filters: list[MetaFilter] | None = None, + embedding: List[float] | None = None, + top_k: int = 5, + ): + query_filter = self.build_qdrant_filter(filters) + try: + if embedding is not None: + response = self.client.query_points( + collection_name=self.collection_name, + query=embedding, + query_filter=query_filter, + limit=top_k, + with_payload=True, + ) + + points = response.points + with_score = True + + else: + points, _ = self.client.scroll( + collection_name=self.collection_name, + scroll_filter=query_filter, + limit=top_k, + with_payload=True, + ) + with_score = False + + return [ + { + "id": p.id, + "score": p.score if with_score else None, + "content": p.payload.get("content"), + "metadata": p.payload.get("metadata"), + } + for p in points + ] + + except Exception as e: + print(f"[ERROR] Qdrant query failed: {e}") + return [] + + def get_all_documents(self): + try: + points, _ = self.client.scroll( + collection_name=self.collection_name, + limit=10000, # Adjust as needed + with_payload=True + ) + return [ + { + "id": p.id, + "content": p.payload.get("content"), + "metadata": p.payload.get("metadata"), + } + for p in points + ] + except Exception as e: + print(f"[ERROR] Failed to retrieve all documents: {e}") + return [] + + def get_all_files(self): + try: + points, _ = self.client.scroll( + collection_name=self.collection_name, + limit=10000, # Adjust as needed + with_payload=True + ) + files_usernames_courses = set() + for p in points: + metadata = p.payload.get("metadata", {}) + source = metadata.get("source") + username = metadata.get("username") + course = metadata.get("course") + if source and username and course: + files_usernames_courses.add((source, username, course)) + + return list(files_usernames_courses) + except Exception as e: + print(f"[ERROR] Failed to retrieve all files: {e}") + return [] + + def remove_collection(self): + try: + self.client.delete_collection(collection_name=self.collection_name) + print(f"[INFO] Collection '{self.collection_name}' deleted.") + except Exception as e: + print(f"[ERROR] Failed to delete collection: {e}") + + def list_collections(self): + try: + collections = self.client.get_collections().collections + return [c.name for c in collections] + except Exception as e: + print(f"[ERROR] Failed to list collections: {e}") + return [] + + def remove_points_by_file(self, source_file: str,username: str ,course: str): + try: + response, _ = self.client.scroll( + collection_name=self.collection_name, + scroll_filter=Filter( + must=[ + FieldCondition( + key="metadata.source", + match=MatchValue(value=source_file) + ), + FieldCondition( + key="metadata.username", + match=MatchValue(value=username) + ), + FieldCondition( + key="metadata.course", + match=MatchValue(value=course) + ) + ] + ), + limit=10000, # Adjust as needed + with_payload=False + ) + point_ids = [p.id for p in response] + print(f"[INFO] Found {len(point_ids)} points for file '{source_file}' to delete.") + if point_ids: + self.client.delete( + collection_name=self.collection_name, + points_selector=PointIdsList(points=point_ids) + ) + print(f"[INFO] Deleted {len(point_ids)} points for file '{source_file}'") + return True + else: + print(f"[INFO] No points found for file '{source_file}' to delete.") + return False + except Exception as e: + print(f"[ERROR] Failed to delete points for file '{source_file}': {e}") + return False + + def all_user_files_bookmarks(self, username: str): + try: + raw: dict[str, list[list[str]]] = {} + next_offset = None + + while True: + response, next_offset = self.client.scroll( + collection_name=self.collection_name, + scroll_filter=Filter( + must=[ + FieldCondition( + key="metadata.username", + match=MatchValue(value=username) + ) + ] + ), + limit=100, + offset=next_offset, + with_payload=True, + with_vectors=False + ) + + for p in response: + metadata = p.payload.get("metadata", {}) + source = metadata.get("source") + bookmark_path = metadata.get("bookmark_path") # list like ["Part", "Chapter", "Section"] + + if not source or not isinstance(bookmark_path, list): + continue + + if source not in raw: + raw[source] = [] + + if bookmark_path not in raw[source]: + raw[source].append(bookmark_path) + + if next_offset is None: + break + + # Build nested dict: source → part → chapter → [sections] + result = {} + for source, paths in raw.items(): + nested = {} + for path in paths: + if len(path) == 0: + continue + + part = path[0] + chapter = path[1] if len(path) > 1 else None + section = path[2] if len(path) > 2 else None + + nested.setdefault(part, {}) + + if chapter is None: + # top-level bookmark (e.g. ["Preface"]) + nested[part].setdefault("_sections", []) + continue + + nested[part].setdefault(chapter, []) + + if section and section not in nested[part][chapter]: + nested[part][chapter].append(section) + + result[source] = nested + + print(f"[INFO] Retrieved grouped bookmarks for user '{username}': {result}") + return result + + except Exception as e: + print(f"[ERROR] Failed to retrieve user files and bookmarks: {e}") + return {} + + def retrieve_chunks_by_topic(self,username: str,course: str,topic_embeddings, + refernces: Optional[List[dict]] = None,chunks_per_topic: int = 5): + bookmarked_only = False + metadata_filter = [ + {"field": "username", "op": "eq", "value": username, "clause": "must"}, + {"field": "course", "op": "eq", "value": course, "clause": "must"}, + ] + results = [] + if refernces: + for ref in refernces: + metadata_filter.append({"field": "source", "op": "eq", "value": ref.filename, "clause": "must"}) + bookmarks=ref.bookmarks if ref.bookmarks else [] + #print(bookmarks) + if bookmarks == []: + ten=self.query_qdrant( + filters=metadata_filter, + embedding=topic_embeddings, + top_k=chunks_per_topic) + for one in ten: + results.append(one) + else: + bookmarked_only = True + bookmarks_length = len(bookmarks) + for bookmark in bookmarks: + metadata_filter.append({"field": "bookmark_path", "op": "text", "value": bookmark, "clause": "must"}) + ten=self.query_qdrant( + filters=metadata_filter, + embedding=topic_embeddings, + top_k=chunks_per_topic//bookmarks_length + ) + for one in ten: + results.append(one) + metadata_filter.pop() # remove bookmark filter + metadata_filter.pop() # remove source filter + if not refernces: + ten=self.query_qdrant( + filters=metadata_filter, + embedding=topic_embeddings, + top_k=chunks_per_topic) + for one in ten: + results.append(one) + + + if bookmarked_only: + results = [r for r in results if r.get("metadata", {}).get("bookmark_path")] + else: + bookmarked = [r for r in results if r.get("metadata", {}).get("bookmark_path")] + non_bookmarked = [r for r in results if not r.get("metadata", {}).get("bookmark_path")] + results = [] + + while len(results) < chunks_per_topic and (bookmarked or non_bookmarked): + if bookmarked: + results.append(bookmarked.pop(0)) + if non_bookmarked and len(results) < chunks_per_topic: + results.append(non_bookmarked.pop(0)) + + results = results[:chunks_per_topic] + + return results[:chunks_per_topic] + + def retrieve_for_exam(self,topics: List,username: str,course: str = None, + references: Optional[List[dict]] = None,chunks_per_topic: int = 5): + + exam_chunks = {} + + for topic in topics: + chunks = self.retrieve_chunks_by_topic( + username=username, + course=course, + topic_embeddings=topic[1], # topic[0] = str topic [1] = embeddings + refernces=references, + chunks_per_topic=chunks_per_topic + ) + #print(chunks) + exam_chunks[topic[0]] = chunks + + return exam_chunks + diff --git a/testcolabollama.py b/testcolabollama.py new file mode 100644 index 0000000000000000000000000000000000000000..18e1fdad5a862b8dac218770a592b1719e619b9f --- /dev/null +++ b/testcolabollama.py @@ -0,0 +1,11 @@ +import requests + +url = "https://elwanda-agnathous-tragically.ngrok-free.dev/api/embeddings" +headers = { + "Content-Type": "application/json" +} +data = {"model": "embeddinggemma:latest", "prompt": "Hello world"} + +r = requests.post(url, json=data, headers=headers) +print(r.status_code) +print(r.text) \ No newline at end of file diff --git a/webhook.py b/webhook.py new file mode 100644 index 0000000000000000000000000000000000000000..04c7d41779fe6f27771b6bf69a2d4e6523c3dc1d --- /dev/null +++ b/webhook.py @@ -0,0 +1,191 @@ +from fastapi import FastAPI, Request +from fastapi.responses import HTMLResponse, RedirectResponse +import json +from datetime import datetime + +app = FastAPI() + +# In-memory storage +webhook_logs = [] + + +@app.post("/webhook") +async def webhook(request: Request): + body = await request.body() + + try: + body_json = json.loads(body) + body_pretty = json.dumps(body_json, indent=2) + except Exception: + body_pretty = body.decode("utf-8", errors="ignore") + + webhook_logs.insert(0, { + "time": datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S UTC"), + "headers": dict(request.headers), + "body": body_pretty + }) + + del webhook_logs[50:] + return {"status": "ok"} + + +@app.post("/clear") +async def clear_logs(): + webhook_logs.clear() + return RedirectResponse("/", status_code=303) + + +@app.get("/", response_class=HTMLResponse) +async def dashboard(): + items = "" + for i, log in enumerate(webhook_logs): + items += f""" +
{json.dumps(log["headers"], indent=2)}
+ {log["body"]}
+ No webhooks received yet
"} +