|
|
|
|
|
import json
import os
import re
from typing import Optional

import fasttext
import fitz
import google.generativeai as genai
import torch
from fastapi import FastAPI
from huggingface_hub import hf_hub_download
from IndicTransToolkit.processor import IndicProcessor
from PIL import Image
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
|
|
|
|
|
# FastAPI application object; route handlers in this module register on it.
app = FastAPI()
|
|
|
|
|
|
|
|
|
# Gemini credentials are read from the environment; None when unset.
GEMINI_API_KEY = os.getenv("GEMINI_API_KEY")

# Hugging Face identifiers of the translation and OCR models.
TRANSLATION_MODEL_REPO_ID = "ai4bharat/indictrans2-indic-en-1B"
OCR_MODEL_ID = "microsoft/trocr-base-printed"

# Language code that marks a line as needing translation.
LANGUAGE_TO_TRANSLATE = "mal"

# Torch device the translation model and its inputs are moved to.
DEVICE = "cpu"
|
|
|
|
|
|
|
|
|
# Configure the Gemini SDK only when a key is available; otherwise just warn
# on stdout so the module can still be imported.
if not GEMINI_API_KEY:
    print("🔴 GEMINI_API_KEY not set.")
else:
    genai.configure(api_key=GEMINI_API_KEY)
|
|
|
|
|
|
|
|
|
# Tokenizer and seq2seq model for translation, loaded once at import time.
# NOTE: trust_remote_code=True lets the model repository's own Python code run
# during loading.
translation_tokenizer = AutoTokenizer.from_pretrained(
    TRANSLATION_MODEL_REPO_ID,
    trust_remote_code=True,
)
translation_model = AutoModelForSeq2SeqLM.from_pretrained(
    TRANSLATION_MODEL_REPO_ID,
    trust_remote_code=True,
    torch_dtype=torch.float32,
).to(DEVICE)

# Pre-/post-processor applied around the translation model in translate_chunk().
ip = IndicProcessor(inference=True)
|
|
|
|
|
|
# Download (or reuse the cached copy of) the fastText language-identification
# weights, then load them for detect_language().
ft_model_path = hf_hub_download(
    repo_id="facebook/fasttext-language-identification",
    filename="model.bin",
)
lang_detect_model = fasttext.load_model(ft_model_path)
|
|
|
|
|
|
# Image-to-text pipeline used for images classified as documents.
# NOTE(review): per its model card, trocr-base-printed targets single
# text-line images — confirm it is adequate for full-page scans.
ocr_pipeline = pipeline(
    task="image-to-text",
    model=OCR_MODEL_ID,
    device=-1,
)
|
|
|
|
|
|
|
|
|
def classify_image_with_gemini(image: Image.Image):
    """Ask Gemini whether *image* is a text document or a technical diagram.

    The prompt and the image are sent to the 'gemini-1.5-flash-latest' model.
    Returns "diagram" when the lower-cased reply contains the word "diagram";
    any other reply is treated as "document".
    """
    prompt = "Is this image primarily a text document or an engineering/technical diagram? Answer with only 'document' or 'diagram'."
    gemini = genai.GenerativeModel('gemini-1.5-flash-latest')
    reply = gemini.generate_content([prompt, image])
    answer = reply.text.strip().lower()
    if "diagram" in answer:
        return "diagram"
    return "document"
|
|
|
|
|
|
def summarize_diagram_with_gemini(image: Image.Image):
    """Return Gemini's concise description of a technical diagram.

    The prompt and *image* are sent to the 'gemini-1.5-flash-latest' model and
    the reply text is returned with surrounding whitespace stripped.
    """
    prompt = "Describe the contents of this technical diagram in a concise summary."
    gemini = genai.GenerativeModel('gemini-1.5-flash-latest')
    reply = gemini.generate_content([prompt, image])
    summary = reply.text
    return summary.strip()
|
|
|
|
|
|
def extract_text_from_image(path):
    """Return text for the image at *path*.

    The image is first classified with Gemini. Diagrams are summarised by
    Gemini; everything else is run through the OCR pipeline.

    Args:
        path: Filesystem path of the image file.

    Returns:
        The diagram summary, or the OCR pipeline's first "generated_text"
        entry, or "" when the pipeline returns nothing.
    """
    # Image.open() keeps the underlying file open; convert() returns a
    # converted copy, so the file can be closed as soon as the copy exists.
    with Image.open(path) as source:
        image = source.convert("RGB")

    if classify_image_with_gemini(image) == "diagram":
        return summarize_diagram_with_gemini(image)

    out = ocr_pipeline(image)
    return out[0]["generated_text"] if out else ""
|
|
|
|
|
|
def extract_text_from_pdf(path):
    """Return the plain text of every page of the PDF at *path*.

    Args:
        path: Filesystem path of the PDF file.

    Returns:
        The pages' text in page order, each page followed by a newline.
    """
    # Use the document as a context manager so it is closed even when text
    # extraction raises; join() consumes the pages before the block exits.
    with fitz.open(path) as doc:
        return "".join(page.get_text("text") + "\n" for page in doc)
|
|
|
|
|
|
def read_text_from_txt(path):
    """Return the full contents of the UTF-8 text file at *path*.

    Args:
        path: Filesystem path of the text file.

    Returns:
        The decoded file contents as a single string.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # "utf-8-sig" decodes plain UTF-8 identically but also drops a leading
    # byte-order mark, which would otherwise end up as U+FEFF at the start of
    # the first line handed to language detection.
    with open(path, "r", encoding="utf-8-sig") as f:
        return f.read()
|
|
|
|
|
|
def detect_language(text_snippet):
    """Return the top language label predicted for *text_snippet*.

    Newlines are replaced by spaces before the snippet is handed to the
    fastText model. The result is whatever follows the last "__" in the
    best-scoring label (so a label such as "__label__xyz" yields "xyz").
    Returns None for blank input or when the model returns no label.
    """
    flattened = text_snippet.replace("\n", " ").strip()
    if not flattened:
        return None

    preds = lang_detect_model.predict(flattened, k=1)
    if not (preds and preds[0]):
        return None

    # Keep only the final "__"-separated piece of the top label.
    *_, code = preds[0][0].split("__")
    return code
|
|
|
|
|
|
def translate_chunk(chunk):
    """Translate one piece of text from "mal_Mlym" to "eng_Latn".

    The text is pre-processed with the IndicProcessor, tokenized (truncated to
    512 tokens), decoded with 5-beam search, and post-processed. Returns the
    single translated string.
    """
    prepared = ip.preprocess_batch(
        [chunk],
        src_lang="mal_Mlym",
        tgt_lang="eng_Latn",
    )
    encoded = translation_tokenizer(
        prepared,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(DEVICE)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output_ids = translation_model.generate(
            **encoded,
            num_beams=5,
            max_length=512,
            early_stopping=True,
        )

    texts = translation_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    english = ip.postprocess_batch(texts, lang="eng_Latn")
    return english[0]
|
|
|
|
|
|
# Response schema passed to Gemini so the reply is a JSON object with a
# summary, required actions, departments to notify and cross-document flags.
_ANALYSIS_SCHEMA = {
    "type": "OBJECT",
    "properties": {
        "summary": {"type": "STRING"},
        "actions_required": {
            "type": "ARRAY",
            "items": {
                "type": "OBJECT",
                "properties": {
                    "action": {"type": "STRING"},
                    "priority": {"type": "STRING", "enum": ["High", "Medium", "Low"]},
                    "deadline": {"type": "STRING"},
                    "notes": {"type": "STRING"},
                },
                "required": ["action", "priority", "deadline", "notes"],
            },
        },
        "departments_to_notify": {"type": "ARRAY", "items": {"type": "STRING"}},
        "cross_document_flags": {
            "type": "ARRAY",
            "items": {
                "type": "OBJECT",
                "properties": {
                    "related_document_type": {"type": "STRING"},
                    "related_issue": {"type": "STRING"},
                },
                "required": ["related_document_type", "related_issue"],
            },
        },
    },
    "required": ["summary", "actions_required", "departments_to_notify", "cross_document_flags"],
}


def generate_structured_json(text_to_analyze) -> Optional[dict]:
    """Ask Gemini for a structured JSON analysis of *text_to_analyze*.

    Args:
        text_to_analyze: Document text embedded verbatim in the prompt.

    Returns:
        The parsed JSON reply, or None when the reply has no usable text or
        is not valid JSON, so callers can report a failed analysis instead of
        crashing.
    """
    model = genai.GenerativeModel('gemini-1.5-flash-latest')
    prompt = f"Analyze this document and extract key info as JSON: {text_to_analyze}"
    generation_config = genai.types.GenerationConfig(
        response_mime_type="application/json",
        response_schema=_ANALYSIS_SCHEMA,
    )
    response = model.generate_content(prompt, generation_config=generation_config)

    try:
        return json.loads(response.text)
    except ValueError:
        # json.JSONDecodeError is a ValueError subclass. The SDK's `.text`
        # accessor is also expected to raise ValueError when the reply
        # carries no text part (e.g. a blocked response) — TODO confirm
        # against the installed SDK version.
        return None
|
|
|
|
|
|
def check_relevance_with_gemini(summary_text):
    """Ask Gemini whether a summary concerns transport or infrastructure.

    Args:
        summary_text: Summary embedded verbatim in the prompt.

    Returns:
        True when the reply contains "yes" as a whole word (case-insensitive),
        False otherwise.
    """
    model = genai.GenerativeModel('gemini-1.5-flash-latest')
    prompt = f'Is this summary relevant to transportation, infrastructure, railways, or metro systems? Answer "Yes" or "No". Summary: {summary_text}'
    response = model.generate_content(prompt)
    answer = response.text.strip().lower()
    # Match "yes" as a whole word: a plain substring test would also accept
    # replies such as "No, it is about eyes" or "No ... yesterday".
    return re.search(r"\byes\b", answer) is not None
|
|
|
|
|
|
|
|
|
class InputFile(BaseModel):
    """Request body accepted by the /predict endpoint."""

    # Path of the file to analyse; it is opened on the server's filesystem.
    file_path: str
|
|
|
|
|
|
# File extensions predict() knows how to extract text from.
_SUPPORTED_EXTENSIONS = frozenset({".pdf", ".txt", ".png", ".jpg", ".jpeg"})


def _extract_text(path, ext):
    """Return the text of *path* using the extractor that matches *ext*."""
    if ext == ".pdf":
        return extract_text_from_pdf(path)
    if ext == ".txt":
        return read_text_from_txt(path)
    return extract_text_from_image(path)


def _translate_lines(text):
    """Translate lines detected as LANGUAGE_TO_TRANSLATE; drop blank lines."""
    translated_lines = []
    for ln in text.split("\n"):
        if not ln.strip():
            continue
        lang = detect_language(ln)
        # The language-ID model reports labels of the form "<lang>_<Script>"
        # (e.g. "mal_Mlym"), so compare only the language part with the bare
        # code in LANGUAGE_TO_TRANSLATE; a bare "mal" compares equal as well.
        lang_code = lang.split("_")[0] if lang else None
        if lang_code == LANGUAGE_TO_TRANSLATE:
            translated_lines.append(translate_chunk(ln))
        else:
            translated_lines.append(ln)
    return "\n".join(translated_lines)


@app.post("/predict")
def predict(file: InputFile):
    """Extract, translate and analyse the document named in the request.

    Steps: read the file's text (PDF, TXT or image), translate the lines
    detected as LANGUAGE_TO_TRANSLATE, ask Gemini for a structured analysis,
    then ask Gemini whether the summary is relevant.

    Args:
        file: Request body carrying the path of the file to process.

    Returns:
        The structured analysis dict when the document is relevant;
        {"status": "Not Applicable", ...} when it is not; or an
        {"error": ...} dict when the API key is missing, the file type is
        unsupported, the file does not exist, or the analysis failed.
    """
    if not GEMINI_API_KEY:
        return {"error": "Gemini API key not set."}

    # NOTE(security): file_path comes straight from the request body and is
    # read from the server's filesystem; restrict it to a trusted directory
    # if this endpoint is reachable by untrusted clients.
    path = file.file_path
    ext = os.path.splitext(path)[1].lower()
    if ext not in _SUPPORTED_EXTENSIONS:
        return {"error": "Unsupported file type."}
    if not os.path.isfile(path):
        # Report a missing file in the same error-dict style as the other
        # failures instead of letting the extractor raise.
        return {"error": "File not found."}

    original_text = _extract_text(path, ext)
    final_text = _translate_lines(original_text)

    summary_data = generate_structured_json(final_text)
    if not summary_data or "summary" not in summary_data:
        return {"error": "Failed to generate analysis."}

    if check_relevance_with_gemini(summary_data["summary"]):
        return summary_data
    return {"status": "Not Applicable", "reason": "Document not relevant to KMRL."}
|
|
|
|