import os
import re
import pickle
import logging
import asyncio

import torch
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import fitz  # PyMuPDF
import faiss
import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import pipeline

from config import (
    ALL_FILES,
    MATH_FILES,
    SCIENCE_FILES,
    DATA_DIR,
    DOCUMENTS_PATH,
    FAISS_INDEX_PATH,
    HUGGINGFACE_TOKEN,
    MODEL_ID,
)

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


app = FastAPI(title="Swahili Content Generation API")

# Allow cross-origin requests so any web frontend can reach the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class PromptRequest(BaseModel):
    prompt: str


class ContentRequest(BaseModel):
    grade: int
    subject: str
    topic: str
    style: str = "normal"
    content_length: str = "medium"

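# A sample /generate request body for ContentRequest (illustrative values;
# the topic must match a canonical keyword in TOPIC_KEYWORDS below):
#
#   {
#       "grade": 3,
#       "subject": "science",
#       "topic": "mazingira",
#       "style": "simple",
#       "content_length": "short"
#   }

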
# Maps each source file to its topic keyword(s). The first keyword in each
# list is the canonical topic name that clients send in requests.
TOPIC_KEYWORDS = {
    # Science, grade 3 (PDF sources)
    'mazingira g3.pdf': ['mazingira'],
    'nishati g3.pdf': ['nishati'],
    'maada g3.pdf': ['maada'],
    'mawasiliano g3.pdf': ['mawasiliano'],
    'usafi g3.pdf': ['usafi'],
    'vipimo g3.pdf': ['vipimo-s'],
    'mlo g3.pdf': ['mlo'],
    'mfumo g3.pdf': ['mfumo'],
    'maambukizi g3.pdf': ['maambukizi'],
    'huduma g3.pdf': ['huduma'],
    'vifaa g3.pdf': ['vifaa'],

    # Science, grade 4 (text sources)
    'kinga ya mwili g4.txt': ['kinga'],
    'magonjwa g4.txt': ['magonjwa'],
    'majaribio ya kisayansi g4.txt': ['majaribio'],
    'maji g4.txt': ['maji'],
    'ukimwi g4.txt': ['ukimwi'],
    'huduma g4.txt': ['huduma-g4'],
    'mazingira g4.txt': ['mazingira-g4'],
    'matumizi ya nishati g4.txt': ['matumizi-nishati-g4'],
    'nishati g4.txt': ['nishati-g4'],
    'mfumo g4.txt': ['mfumo-g4'],
    'mawasiliano g4.txt': ['mawasiliano-g4'],

    # Math, grade 3 (text sources)
    'namba g3.txt': ['namba'],
    'mpangilio g3.txt': ['mpangilio'],
    'matendo katika namba g3.txt': ['matendo'],
    'kutambua sehemu g3.txt': ['sehemu'],
    'kutambua maumbo g3.txt': ['maumbo'],
    'vipimo g3.txt': ['vipimo'],
    'fedha g3.txt': ['fedha'],
    'takwimu kwa picha g3.txt': ['takwimu'],

    # Math, grade 4 (text sources)
    'kugawanya namba g4.txt': ['kugawanya'],
    'kujumlisha namba g4.txt': ['kujumlisha'],
    'kuzidisha namba g4.txt': ['kuzidisha'],
    'namba nzima g4.txt': ['namba-g4'],
    'namba za kirumi g4.txt': ['kirumi'],
    'wakati g4.txt': ['wakati'],
    'mpangilio g4.txt': ['mpangilio-g4'],
    'vipimo g4.txt': ['vipimo-g4'],
    'takwimu g4.txt': ['takwimu-g4'],
    'kutoa namba g4.txt': ['kutoa'],
    'fedha g4.txt': ['fedha-g4'],
    'sehemu g4.txt': ['sehemu-g4'],
    'maumbo g4.txt': ['maumbo-g4'],
}


def preprocess_pdf_text(text):
    """Strip scanned-PDF boilerplate and stray characters from page text."""
    words_to_remove = ['FOR', 'ONLINE', 'USE', 'ONLY', 'DO', 'NOT', 'DUPLICATE', 'SAYANSI', 'STD', 'PM']
    pattern = r'\b(?:' + '|'.join(map(re.escape, words_to_remove)) + r')\b'
    text = re.sub(pattern, '', text, flags=re.IGNORECASE)

    # Collapse whitespace, drop anything that is not a word character,
    # whitespace, basic punctuation, or an accented vowel, then re-collapse.
    text = ' '.join(text.split())
    text = re.sub(r'[^\w\s\.\,\?\!\'\"àèìòùÀÈÌÒÙáéíóúÁÉÍÓÚâêîôûÂÊÎÔÛãẽĩõũÃẼĨÕŨ]', ' ', text)
    text = ' '.join(text.split())
    return text

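# A minimal sanity check of the cleanup above (illustrative input, not taken
# from the real textbooks): watermark words are removed and whitespace
# collapsed.
#
#   >>> preprocess_pdf_text("FOR ONLINE USE ONLY   Maji ni muhimu.")
#   'Maji ni muhimu.'

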
def extract_text_from_file(file_path):
    """Dispatch to the right extractor based on the file extension."""
    if file_path.lower().endswith('.pdf'):
        return extract_text_from_pdf(file_path)
    elif file_path.lower().endswith('.txt'):
        try:
            with open(file_path, 'r', encoding='utf-8') as file:
                return file.read().strip()
        except Exception as e:
            logger.error(f"Error reading text file {file_path}: {str(e)}")
            return ""
    else:
        logger.error(f"Unsupported file type for {file_path}")
        return ""


def extract_text_from_pdf(pdf_path):
    """Extract and clean text from every page of a PDF."""
    text = ""
    with fitz.open(pdf_path) as doc:
        for page_num, page in enumerate(doc):
            try:
                blocks = page.get_text("blocks")
                page_text = "\n".join(block[4] for block in blocks)
                cleaned_text = preprocess_pdf_text(page_text)
                text += cleaned_text + "\n"
            except Exception as e:
                logger.error(f"Error processing page {page_num + 1}: {str(e)}")
                continue

    return text.strip()

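# Usage sketch (hypothetical path; real inputs come from ALL_FILES in config):
#
#   raw = extract_text_from_file("data/mazingira g3.pdf")
#   print(raw[:200])  # first 200 cleaned characters

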
def split_text_into_chunks(text, source_file, chunk_size=500, overlap=50):
    """Split text into sentence-aligned chunks of at most ~chunk_size words,
    carrying ~overlap words of context between consecutive chunks."""
    text = ' '.join(text.strip().split())

    # Tag every chunk with its source file and topic keywords so retrieval
    # can filter by topic later.
    filename = os.path.basename(source_file)
    keywords = TOPIC_KEYWORDS.get(filename, [])

    sentences = nltk.sent_tokenize(text)
    chunks = []
    current_chunk = []
    current_size = 0

    for sentence in sentences:
        sentence_words = len(sentence.split())

        if current_size + sentence_words > chunk_size:
            if current_chunk:
                chunks.append({
                    'text': ' '.join(current_chunk),
                    'source': filename,
                    'keywords': keywords
                })

                # Seed the next chunk with the last few sentences of this
                # one, up to `overlap` words, to preserve context.
                overlap_size = 0
                overlap_chunk = []
                for s in reversed(current_chunk):
                    if overlap_size + len(s.split()) <= overlap:
                        overlap_chunk.insert(0, s)
                        overlap_size += len(s.split())
                    else:
                        break

                current_chunk = overlap_chunk
                current_size = overlap_size

        current_chunk.append(sentence)
        current_size += sentence_words

    if current_chunk:
        chunks.append({
            'text': ' '.join(current_chunk),
            'source': filename,
            'keywords': keywords
        })

    return chunks

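# Usage sketch (toy text; real inputs are whole textbook files). With
# chunk_size=8 and overlap=3 the second chunk repeats the trailing sentence
# of the first:
#
#   demo = "Maji ni muhimu sana. Mimea inahitaji maji. Watu wanakunywa maji kila siku."
#   for c in split_text_into_chunks(demo, "maji g4.txt", chunk_size=8, overlap=3):
#       print(c['source'], '->', c['text'])

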
def create_faiss_index(texts, embedding_model):
    """Embed all chunk texts and build an exact L2 nearest-neighbour index."""
    doc_embeddings = embedding_model.encode(texts)
    index = faiss.IndexFlatL2(doc_embeddings.shape[1])
    index.add(np.array(doc_embeddings))
    return index

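# Usage sketch (toy corpus; the real index is built over textbook chunks):
#
#   model = SentenceTransformer("all-MiniLM-L6-v2")
#   idx = create_faiss_index(["maji ni muhimu", "namba za kirumi"], model)
#   print(idx.ntotal)  # 2 vectors indexed

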
def retrieve_documents(query, index, embedding_model, documents, top_k=5):
    """Return up to top_k unique chunks for the topic named by `query`."""
    query_lower = query.lower()
    target_topic = None

    # Map the query to a source file via its canonical topic keyword.
    for filename, keywords in TOPIC_KEYWORDS.items():
        if keywords[0] == query_lower:
            target_topic = filename
            break

    # Over-fetch so enough candidates remain after topic filtering.
    query_embedding = embedding_model.encode([query])
    distances, indices = index.search(query_embedding, top_k * 3)

    topic_docs = []
    for idx in indices[0]:
        doc = documents[idx]
        # Keep only chunks from the target file, skipping duplicates. If no
        # file matched the query, fall back to plain semantic ranking.
        if target_topic is None or doc['source'] == target_topic:
            if not any(existing.get('text', '') == doc['text'] for existing in topic_docs):
                topic_docs.append(doc)

        if len(topic_docs) >= top_k:
            break

    final_content = "\n\n".join(doc['text'] for doc in topic_docs[:top_k])
    logger.info(f"Retrieved content from: {target_topic or 'semantic search'}")
    return final_content

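# Usage sketch (assumes the index and documents built at startup; "mazingira"
# is the canonical keyword for 'mazingira g3.pdf'):
#
#   context = retrieve_documents("mazingira", app.state.faiss_index,
#                                app.state.embedding_model, app.state.documents)
#   print(context[:300])

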
def calculate_bleu(reference, candidate):
    """Calculate a smoothed BLEU score between reference and candidate texts."""
    if isinstance(reference, list):
        reference = " ".join(reference)
    if isinstance(candidate, list):
        candidate = " ".join(candidate)

    reference_tokens = [reference.split()]
    candidate_tokens = candidate.split()
    smoothing = SmoothingFunction().method1
    return sentence_bleu(reference_tokens, candidate_tokens, smoothing_function=smoothing)

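# Usage sketch: identical texts of four or more words score 1.0; unrelated
# texts score near 0. Smoothing keeps sparse n-gram counts from zeroing out
# the score on short sentences.
#
#   score = calculate_bleu("watoto watatu wanacheza mpira shuleni",
#                          "watoto watatu wanacheza mpira nyumbani")
#   print(f"BLEU: {score:.3f}")

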
def get_topic_files(grade: int, subject: str, topic: str):
    """Return the source files matching a grade, subject, and topic keyword."""
    topic_lower = topic.lower()
    file_list = MATH_FILES if subject.lower() == "math" else SCIENCE_FILES

    matching_files = []
    for file in file_list:
        if f"g{grade}" in file.lower():
            filename = os.path.basename(file)
            if filename in TOPIC_KEYWORDS:
                keywords = TOPIC_KEYWORDS[filename]
                if topic_lower == keywords[0]:
                    matching_files.append(file)

    return matching_files

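# Usage sketch (actual paths depend on how MATH_FILES and SCIENCE_FILES are
# defined in config):
#
#   files = get_topic_files(4, "math", "kirumi")
#   # -> the entry for 'namba za kirumi g4.txt', or [] if the topic is unknown

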
def generate_response_with_rag(prompt, index, embedding_model, documents, settings):
    """Retrieve topic context and generate structured Swahili lesson content."""
    retrieved_context = retrieve_documents(prompt, index, embedding_model, documents)

    logger.info("Context sent to model:")
    logger.info("-" * 50)
    logger.info(retrieved_context)
    logger.info("-" * 50)

    style_instructions = {
        "simple": "Provide clear and easy-to-understand answers using common words and short sentences. Explain concepts as if talking to a young student.",
        "creative": "Give creative and engaging answers, using real-life examples and illustrations to make the content interesting and memorable.",
        "normal": ""
    }

    content_length_instructions = {
        "short": "Keep your response brief and concise. Focus on the most essential points only. Provide only 2-3 subtopics, 1-2 activities, and 3-4 practice questions.",
        "medium": "",
        "long": "Provide a comprehensive and detailed explanation. Include more examples, detailed explanations for each subtopic (at least 4-5 subtopics), 3-4 activities, and 6-8 practice questions."
    }

    instruction = style_instructions.get(settings.get("style", "normal"), "")
    length_instruction = content_length_instructions.get(settings.get("content_length", "medium"), "")

    system_prompt = f"""
Explain the topic of "{settings['topic']}" in detail following this structure:
1. Summary: Briefly explain what the student will learn in this topic (5-6 sentences).
2. Introduction to the topic: Provide background information about the topic before breaking it down into subtopics.
3. Subtopics: Explain each subtopic in detail, providing real-life examples where necessary. For each subtopic, describe images that could help explain the topic, using text instead of actual images, in this format: [Picture: Image description]. Don't provide more than 3 [Picture: Image description] entries, and don't let any image description exceed four words.
4. Activities: After each subtopic, provide small exercises or activities that the student can do to enhance understanding.
5. Practice questions: Provide 6-8 questions related to the topic to reinforce the student's understanding.

**Respond to all questions and instructions in Swahili, and keep all headings in Swahili. Example: use Maswali instead of Questions. Don't use words like subtopic or activity; use their Swahili equivalents instead.**

Please follow these instructions:
* Fully written-out numbers (e.g., "watoto watatu" instead of "3")
* No abbreviations (e.g., "Shule ya Msingi" not "Sh. ya Msingi")
* Swahili-only words or phonetic "Swahilicized" versions for foreign terms
* Minimal punctuation: just commas and full stops
* Quotations, brackets, symbols removed
* Natural, spoken tone fit for young learners

IMPORTANT: Make sure at least two [Picture: Image description] entries are included in the response.

{instruction}
{length_instruction}

Context:
{retrieved_context}
"""

    messages = [{"role": "system", "content": system_prompt}]
    outputs = app.state.pipe(messages, max_new_tokens=2000)

    try:
        if not outputs or len(outputs) == 0:
            logger.error("No output generated")
            return {
                "content": "Failed to generate response",
                "context": retrieved_context
            }

        # Chat pipelines return the full message list; plain text-generation
        # pipelines return a string.
        generated_messages = outputs[0]['generated_text']
        if isinstance(generated_messages, list):
            for message in generated_messages:
                if message.get('role') == 'assistant':
                    response_content = message.get('content', '')
                    break
            else:
                logger.error("No assistant response found in messages")
                return {
                    "content": "Failed to generate response",
                    "context": retrieved_context
                }
        else:
            response_content = generated_messages

        if not response_content:
            logger.error("Empty response content")
            return {
                "content": "Failed to generate response",
                "context": retrieved_context
            }

        response_content = response_content.strip()

        # Break long single-line paragraphs into one sentence per line, then
        # convert newlines to <br> for HTML rendering.
        paragraphs = [p.strip() for p in response_content.split('\n\n') if p.strip()]
        formatted_paragraphs = []
        for paragraph in paragraphs:
            if len(paragraph) > 100 and '\n' not in paragraph:
                sentences = [s.strip() for s in nltk.sent_tokenize(paragraph)]
                formatted_paragraphs.append('\n'.join(sentences))
            else:
                formatted_paragraphs.append(paragraph)

        response_content = '\n\n'.join(formatted_paragraphs)
        response_content = response_content.replace('\n', '<br>')

        return {
            "content": response_content,
            "context": retrieved_context
        }

    except Exception as e:
        logger.error(f"Error processing response: {e}")
        logger.error(f"Raw output: {outputs}")
        return {
            "content": "Error processing response",
            "context": retrieved_context
        }

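# The parsing above assumes the chat-style output shape of a transformers
# text-generation pipeline called with a message list, e.g.:
#
#   [{'generated_text': [
#       {'role': 'system', 'content': '...the system prompt...'},
#       {'role': 'assistant', 'content': 'Muhtasari: ...'}
#   ]}]

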
async def load_or_create_index():
    """Load the cached FAISS index and chunked documents, or build them from
    the source files on first run."""
    embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
    os.makedirs(DATA_DIR, exist_ok=True)
    os.makedirs(os.path.dirname(FAISS_INDEX_PATH), exist_ok=True)

    try:
        with open(DOCUMENTS_PATH, 'rb') as f:
            documents = pickle.load(f)
        index = faiss.read_index(FAISS_INDEX_PATH)
        logger.info("FAISS index and documents loaded successfully.")
        return index, documents, embedding_model
    except FileNotFoundError:
        logger.info("Index and documents not found. Proceeding to create them.")

    documents = []
    files_found = False
    for file_path in ALL_FILES:
        if not os.path.exists(file_path):
            logger.warning(f"File not found: {file_path}")
            continue

        filename = os.path.basename(file_path)
        logger.info(f"Processing {filename}")
        text = extract_text_from_file(file_path)

        if text:
            files_found = True
            chunks = split_text_into_chunks(text, filename)
            documents.extend(chunks)
            # Yield to the event loop between files so startup stays responsive.
            await asyncio.sleep(0)

    if not files_found:
        raise Exception("No valid text or PDF files found in the specified paths")

    texts = [doc['text'] for doc in documents]
    index = create_faiss_index(texts, embedding_model)

    os.makedirs(os.path.dirname(DOCUMENTS_PATH), exist_ok=True)
    with open(DOCUMENTS_PATH, 'wb') as f:
        pickle.dump(documents, f)
    faiss.write_index(index, FAISS_INDEX_PATH)
    logger.info("FAISS index and documents saved successfully.")

    return index, documents, embedding_model


@app.on_event("startup")
async def startup_event():
    """Initialize the model, NLTK data, and retrieval index on startup."""
    logger.info("Starting application initialization...")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    logger.info(f"Using device: {device}")
    if device == "cpu":
        logger.warning("GPU not detected. Model will run slower on CPU.")

    # Make sure the sentence tokenizer data is available before indexing.
    nltk_data_dir = os.environ.get('NLTK_DATA', os.path.join(os.path.expanduser('~'), 'nltk_data'))
    os.makedirs(nltk_data_dir, exist_ok=True)

    logger.info("Ensuring NLTK data is available...")
    try:
        try:
            nltk.data.find('tokenizers/punkt', paths=[nltk_data_dir])
            logger.info("NLTK punkt already downloaded")
        except LookupError:
            await asyncio.to_thread(nltk.download, 'punkt', download_dir=nltk_data_dir, quiet=True)

        try:
            nltk.data.find('tokenizers/punkt_tab', paths=[nltk_data_dir])
            logger.info("NLTK punkt_tab already downloaded")
        except LookupError:
            await asyncio.to_thread(nltk.download, 'punkt_tab', download_dir=nltk_data_dir, quiet=True)
    except Exception as e:
        logger.error(f"Error handling NLTK data: {str(e)}")
        raise Exception(f"Failed to initialize application: {str(e)}")

    # Load the generation model, then build or load the retrieval index.
    try:
        app.state.pipe = pipeline(
            "text-generation",
            model=MODEL_ID,
            trust_remote_code=True,
            token=HUGGINGFACE_TOKEN,
            device_map="auto",
            torch_dtype=torch.float16 if device == "cuda" else torch.float32
        )

        faiss_index, documents, embedding_model = await load_or_create_index()

        app.state.faiss_index = faiss_index
        app.state.documents = documents
        app.state.embedding_model = embedding_model

        logger.info("Application initialization completed successfully")
    except Exception as e:
        logger.error(f"Error initializing application: {str(e)}")
        raise Exception(f"Failed to initialize application: {str(e)}")


@app.post("/generate")
async def generate_content(request: ContentRequest):
    try:
        logger.info(f"Generating content for grade {request.grade}, subject {request.subject}, topic {request.topic}")

        # Validate the request before touching the model.
        if request.grade not in [3, 4]:
            raise HTTPException(status_code=400, detail="Invalid grade level. Must be 3 or 4")
        if request.subject.lower() not in ["math", "science"]:
            raise HTTPException(status_code=400, detail="Invalid subject. Must be 'math' or 'science'")
        if request.style not in ["normal", "simple", "creative"]:
            raise HTTPException(status_code=400, detail="Invalid style. Must be 'normal', 'simple', or 'creative'")
        if request.content_length not in ["short", "medium", "long"]:
            raise HTTPException(status_code=400, detail="Invalid content length. Must be 'short', 'medium', or 'long'")

        topic_files = get_topic_files(request.grade, request.subject, request.topic)
        if not topic_files:
            raise HTTPException(status_code=404, detail="Topic not found for specified grade and subject")

        settings = {
            "style": request.style,
            "topic": request.topic,
            "grade": request.grade,
            "subject": request.subject,
            "content_length": request.content_length
        }

        response = generate_response_with_rag(
            request.topic,
            app.state.faiss_index,
            app.state.embedding_model,
            app.state.documents,
            settings
        )

        logger.info("Content generated successfully")
        return {"response": response['content']}

    except HTTPException:
        # Let validation errors keep their own status codes.
        raise
    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

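# Example request (assumes the server is running locally on port 8080, as in
# the __main__ block below):
#
#   curl -X POST http://localhost:8080/generate \
#        -H "Content-Type: application/json" \
#        -d '{"grade": 3, "subject": "science", "topic": "mazingira", "style": "simple", "content_length": "short"}'

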
@app.get("/health")
async def health_check():
    try:
        # Report "starting" until the model pipeline has been loaded.
        if not hasattr(app.state, "pipe"):
            return {"status": "starting", "message": "Model is still loading"}
        return {"status": "healthy"}
    except Exception as e:
        logger.error(f"Health check failed: {str(e)}")
        raise HTTPException(status_code=500, detail="Internal server error")


if __name__ == "__main__":
    try:
        logger.info("Starting FastAPI server...")
        uvicorn.run(app, host="0.0.0.0", port=8080, log_level="info")
    except Exception as e:
        logger.error(f"Application failed to start: {str(e)}")
        raise