# inference_api.py
"""FastAPI service: extract text from a file, translate Malayalam lines to
English, and ask Gemini for a structured analysis of the result.

NOTE: importing this module loads the translation, language-identification
and OCR models (downloading them from the Hugging Face Hub if needed).
"""
import json
import os
from typing import Optional

import fasttext
import fitz  # PyMuPDF
import google.generativeai as genai
import torch
from fastapi import FastAPI
from huggingface_hub import hf_hub_download
from IndicTransToolkit.processor import IndicProcessor
from PIL import Image
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

app = FastAPI()

# === CONFIGURATION ===
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
GEMINI_MODEL_NAME = "gemini-1.5-flash-latest"
TRANSLATION_MODEL_REPO_ID = "ai4bharat/indictrans2-indic-en-1B"
OCR_MODEL_ID = "microsoft/trocr-base-printed"
LANGUAGE_TO_TRANSLATE = "mal"
DEVICE = "cpu"

# --- Configure Gemini ---
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    print("🔴 GEMINI_API_KEY not set.")

# --- Load Models ---
translation_tokenizer = AutoTokenizer.from_pretrained(
    TRANSLATION_MODEL_REPO_ID, trust_remote_code=True
)
translation_model = AutoModelForSeq2SeqLM.from_pretrained(
    TRANSLATION_MODEL_REPO_ID,
    trust_remote_code=True,
    torch_dtype=torch.float32,
).to(DEVICE)
ip = IndicProcessor(inference=True)

ft_model_path = hf_hub_download(
    repo_id="facebook/fasttext-language-identification", filename="model.bin"
)
lang_detect_model = fasttext.load_model(ft_model_path)

ocr_pipeline = pipeline("image-to-text", model=OCR_MODEL_ID, device=-1)


# === HELPER FUNCTIONS ===
def classify_image_with_gemini(image: Image.Image) -> str:
    """Ask Gemini whether *image* is a text document or a technical diagram.

    Returns "diagram" when the model's answer contains that word, otherwise
    "document" (which is therefore also the fallback for unexpected answers).
    """
    model = genai.GenerativeModel(GEMINI_MODEL_NAME)
    prompt = "Is this image primarily a text document or an engineering/technical diagram? Answer with only 'document' or 'diagram'."
    response = model.generate_content([prompt, image])
    classification = response.text.strip().lower()
    return "diagram" if "diagram" in classification else "document"


def summarize_diagram_with_gemini(image: Image.Image) -> str:
    """Return Gemini's concise textual description of a technical diagram."""
    model = genai.GenerativeModel(GEMINI_MODEL_NAME)
    prompt = "Describe the contents of this technical diagram in a concise summary."
    response = model.generate_content([prompt, image])
    return response.text.strip()


def extract_text_from_image(path: str) -> str:
    """Return text for the image at *path*.

    Diagrams are summarised by Gemini; everything else goes through the OCR
    pipeline. Returns "" when OCR produces no output.
    """
    # Open inside a context manager so the file handle is released;
    # convert() yields an independent in-memory copy.
    with Image.open(path) as raw_image:
        image = raw_image.convert("RGB")

    if classify_image_with_gemini(image) == "diagram":
        return summarize_diagram_with_gemini(image)

    out = ocr_pipeline(image)
    return out[0]["generated_text"] if out else ""


def extract_text_from_pdf(path: str) -> str:
    """Return the plain text of every page of the PDF, one page per line break."""
    # The document is closed on exit instead of being leaked.
    with fitz.open(path) as doc:
        return "".join(page.get_text("text") + "\n" for page in doc)


def read_text_from_txt(path: str) -> str:
    """Return the full contents of a UTF-8 text file."""
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


def detect_language(text_snippet: str) -> Optional[str]:
    """Return the fastText label for *text_snippet* without its "__label__" prefix.

    Newlines are flattened to spaces before prediction. Returns None for
    blank input or when the model yields no label.
    """
    s = text_snippet.replace("\n", " ").strip()
    if not s:
        return None
    preds = lang_detect_model.predict(s, k=1)
    return preds[0][0].split("__")[-1] if preds and preds[0] else None


def _is_target_language(lang: Optional[str]) -> bool:
    """Tell whether a detected label denotes LANGUAGE_TO_TRANSLATE.

    The Hub fastText language-identification model emits labels of the form
    "<lang>_<Script>" (e.g. "mal_Mlym"), so only the part before the first
    underscore is compared; a bare "mal" still matches.
    """
    if not lang:
        return False
    return lang.split("_", 1)[0] == LANGUAGE_TO_TRANSLATE


def translate_chunk(chunk: str) -> str:
    """Translate one Malayalam text chunk to English with IndicTrans2.

    Input is truncated to 512 tokens by the tokenizer, so very long chunks
    are only partially translated.
    """
    batch = ip.preprocess_batch([chunk], src_lang="mal_Mlym", tgt_lang="eng_Latn")
    inputs = translation_tokenizer(
        batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(DEVICE)
    with torch.no_grad():
        generated_tokens = translation_model.generate(
            **inputs, num_beams=5, max_length=512, early_stopping=True
        )
    decoded = translation_tokenizer.batch_decode(
        generated_tokens, skip_special_tokens=True
    )
    return ip.postprocess_batch(decoded, lang="eng_Latn")[0]


def generate_structured_json(text_to_analyze: str):
    """Ask Gemini for a schema-constrained JSON analysis of the text.

    Returns the parsed JSON, or None when the response carries no usable
    text or is not valid JSON (the caller treats a falsy result as failure).
    """
    model = genai.GenerativeModel(GEMINI_MODEL_NAME)
    prompt = f"Analyze this document and extract key info as JSON: {text_to_analyze}"
    json_schema = {
        "type": "OBJECT",
        "properties": {
            "summary": {"type": "STRING"},
            "actions_required": {
                "type": "ARRAY",
                "items": {
                    "type": "OBJECT",
                    "properties": {
                        "action": {"type": "STRING"},
                        "priority": {
                            "type": "STRING",
                            "enum": ["High", "Medium", "Low"],
                        },
                        "deadline": {"type": "STRING"},
                        "notes": {"type": "STRING"},
                    },
                    "required": ["action", "priority", "deadline", "notes"],
                },
            },
            "departments_to_notify": {
                "type": "ARRAY",
                "items": {"type": "STRING"},
            },
            "cross_document_flags": {
                "type": "ARRAY",
                "items": {
                    "type": "OBJECT",
                    "properties": {
                        "related_document_type": {"type": "STRING"},
                        "related_issue": {"type": "STRING"},
                    },
                    "required": ["related_document_type", "related_issue"],
                },
            },
        },
        "required": [
            "summary",
            "actions_required",
            "departments_to_notify",
            "cross_document_flags",
        ],
    }
    generation_config = genai.types.GenerationConfig(
        response_mime_type="application/json", response_schema=json_schema
    )
    response = model.generate_content(prompt, generation_config=generation_config)
    try:
        # json.JSONDecodeError subclasses ValueError, which is also what the
        # `.text` accessor raises when the response has no usable candidate.
        return json.loads(response.text)
    except ValueError:
        return None


def check_relevance_with_gemini(summary_text: str) -> bool:
    """Return True when Gemini's answer to the relevance question contains "yes"."""
    model = genai.GenerativeModel(GEMINI_MODEL_NAME)
    prompt = f'Is this summary relevant to transportation, infrastructure, railways, or metro systems? Answer "Yes" or "No". Summary: {summary_text}'
    response = model.generate_content(prompt)
    return "yes" in response.text.strip().lower()


# === API INPUT SCHEMA ===
class InputFile(BaseModel):
    # Path, on the server's filesystem, of the file to analyse.
    file_path: str


@app.post("/predict")
def predict(file: InputFile):
    """Extract, translate and analyse the file named in the request.

    Returns Gemini's structured analysis when the document is judged
    relevant, a "Not Applicable" status when it is not, or an
    ``{"error": ...}`` dict when processing cannot proceed.
    """
    if not GEMINI_API_KEY:
        return {"error": "Gemini API key not set."}

    # SECURITY NOTE(review): file_path is supplied by the client and is read
    # from the server's filesystem as-is; restrict it to an allowed directory
    # before exposing this endpoint to untrusted callers.
    path = file.file_path
    if not os.path.isfile(path):
        return {"error": "File not found."}

    ext = os.path.splitext(path)[1].lower()

    # Phase 1: Extract text
    if ext == ".pdf":
        original_text = extract_text_from_pdf(path)
    elif ext == ".txt":
        original_text = read_text_from_txt(path)
    elif ext in {".png", ".jpg", ".jpeg"}:
        original_text = extract_text_from_image(path)
    else:
        return {"error": "Unsupported file type."}

    # Phase 2: Translate Malayalam if detected (line by line, skipping blanks)
    translated_lines = []
    for ln in original_text.split("\n"):
        if not ln.strip():
            continue
        if _is_target_language(detect_language(ln)):
            translated_lines.append(translate_chunk(ln))
        else:
            translated_lines.append(ln)
    final_text = "\n".join(translated_lines)

    if not final_text:
        return {"error": "No text could be extracted from the file."}

    # Phase 3: Gemini analysis
    summary_data = generate_structured_json(final_text)
    if not summary_data or "summary" not in summary_data:
        return {"error": "Failed to generate analysis."}

    if check_relevance_with_gemini(summary_data["summary"]):
        return summary_data
    return {"status": "Not Applicable", "reason": "Document not relevant to KMRL."}