# inference_api.py
import os
import fitz # PyMuPDF
import fasttext
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from IndicTransToolkit.processor import IndicProcessor
import google.generativeai as genai
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional
import json
app = FastAPI()
# === CONFIGURATION ===
# Gemini key comes from the environment; the /predict endpoint refuses to
# run without it (see predict()).
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
# IndicTrans2 Indic->English 1B model used for Malayalam translation.
TRANSLATION_MODEL_REPO_ID = "ai4bharat/indictrans2-indic-en-1B"
# TrOCR checkpoint for printed-text OCR on images classified as documents.
OCR_MODEL_ID = "microsoft/trocr-base-printed"
# Bare language code compared against detect_language() output per line.
LANGUAGE_TO_TRANSLATE = "mal"
# CPU-only inference; translation pipeline below pins float32 accordingly.
DEVICE = "cpu"
# --- Configure Gemini ---
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    print("🔴 GEMINI_API_KEY not set.")
# --- Load Models ---
# NOTE(review): all models are loaded eagerly at import time, so app
# startup blocks on (possibly large) downloads from the HuggingFace Hub.
translation_tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL_REPO_ID, trust_remote_code=True)
translation_model = AutoModelForSeq2SeqLM.from_pretrained(
    TRANSLATION_MODEL_REPO_ID, trust_remote_code=True, torch_dtype=torch.float32
).to(DEVICE)
# Pre/post-processor for IndicTrans2 (handles tagging and detokenization).
ip = IndicProcessor(inference=True)
# fastText language-identification model (binary fetched from the Hub).
ft_model_path = hf_hub_download(repo_id="facebook/fasttext-language-identification", filename="model.bin")
lang_detect_model = fasttext.load_model(ft_model_path)
# device=-1 forces the transformers pipeline onto CPU.
ocr_pipeline = pipeline("image-to-text", model=OCR_MODEL_ID, device=-1)
# === HELPER FUNCTIONS ===
def classify_image_with_gemini(image: Image.Image):
    """Ask Gemini whether *image* is a text document or a technical diagram.

    Returns "diagram" when the model's answer contains the word
    "diagram"; any other answer is treated as "document".
    """
    gemini = genai.GenerativeModel('gemini-1.5-flash-latest')
    prompt = "Is this image primarily a text document or an engineering/technical diagram? Answer with only 'document' or 'diagram'."
    answer = gemini.generate_content([prompt, image]).text.strip().lower()
    if "diagram" in answer:
        return "diagram"
    return "document"
def summarize_diagram_with_gemini(image: Image.Image):
    """Return a concise Gemini-generated summary of a technical diagram."""
    prompt = "Describe the contents of this technical diagram in a concise summary."
    gemini = genai.GenerativeModel('gemini-1.5-flash-latest')
    reply = gemini.generate_content([prompt, image])
    return reply.text.strip()
def extract_text_from_image(path):
    """Extract text from an image file.

    Diagrams (as classified by Gemini) are summarized instead of OCR'd;
    plain documents go through the TrOCR pipeline.  Returns "" when OCR
    yields no result.
    """
    img = Image.open(path).convert("RGB")
    if classify_image_with_gemini(img) == "diagram":
        return summarize_diagram_with_gemini(img)
    results = ocr_pipeline(img)
    if not results:
        return ""
    return results[0]["generated_text"]
def extract_text_from_pdf(path):
    """Extract plain text from every page of a PDF.

    Args:
        path: Filesystem path to the PDF file.

    Returns:
        The concatenated page text, with a newline appended after each page.
    """
    # Context manager closes the document handle even if extraction raises;
    # the original left the document open (resource leak).
    with fitz.open(path) as doc:
        return "".join(page.get_text("text") + "\n" for page in doc)
def read_text_from_txt(path):
    """Return the entire contents of a UTF-8 text file at *path*."""
    with open(path, mode="r", encoding="utf-8") as handle:
        contents = handle.read()
    return contents
def detect_language(text_snippet):
    """Best-effort language identification for a short piece of text.

    Args:
        text_snippet: Raw text; newlines are flattened to spaces because
            fastText's predict() rejects multi-line input.

    Returns:
        The bare language code (e.g. "mal", "eng"), or None when the
        snippet is empty or no prediction is available.
    """
    flattened = text_snippet.replace("\n", " ").strip()
    if not flattened:
        return None
    preds = lang_detect_model.predict(flattened, k=1)
    if not (preds and preds[0]):
        return None
    # Labels from facebook/fasttext-language-identification look like
    # "__label__mal_Mlym" (language + script).  The original returned
    # "mal_Mlym", which could never equal LANGUAGE_TO_TRANSLATE ("mal"),
    # so Malayalam translation was silently skipped.  Strip the script
    # suffix to yield the bare language code.
    label = preds[0][0].split("__")[-1]
    return label.split("_")[0]
def translate_chunk(chunk, src_lang="mal_Mlym", tgt_lang="eng_Latn"):
    """Translate one chunk of text with IndicTrans2.

    Args:
        chunk: Source-language text (a single line/segment).
        src_lang: IndicTrans2 source language tag.  Defaults to Malayalam
            ("mal_Mlym") for backward compatibility with existing callers.
        tgt_lang: IndicTrans2 target language tag (default English).

    Returns:
        The translated text.
    """
    batch = ip.preprocess_batch([chunk], src_lang=src_lang, tgt_lang=tgt_lang)
    inputs = translation_tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(DEVICE)
    # no_grad: inference only — skip autograd bookkeeping.
    with torch.no_grad():
        generated_tokens = translation_model.generate(**inputs, num_beams=5, max_length=512, early_stopping=True)
    decoded = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return ip.postprocess_batch(decoded, lang=tgt_lang)[0]
def generate_structured_json(text_to_analyze):
    """Run Gemini structured analysis over *text_to_analyze*.

    The response is constrained to a JSON schema covering a summary,
    required actions, departments to notify, and cross-document flags.
    Returns the parsed dict; raises json.JSONDecodeError on malformed
    model output.
    """
    string_t = {"type": "STRING"}
    action_item = {
        "type": "OBJECT",
        "properties": {
            "action": string_t,
            "priority": {"type": "STRING", "enum": ["High", "Medium", "Low"]},
            "deadline": string_t,
            "notes": string_t,
        },
        "required": ["action", "priority", "deadline", "notes"],
    }
    flag_item = {
        "type": "OBJECT",
        "properties": {
            "related_document_type": string_t,
            "related_issue": string_t,
        },
        "required": ["related_document_type", "related_issue"],
    }
    json_schema = {
        "type": "OBJECT",
        "properties": {
            "summary": string_t,
            "actions_required": {"type": "ARRAY", "items": action_item},
            "departments_to_notify": {"type": "ARRAY", "items": string_t},
            "cross_document_flags": {"type": "ARRAY", "items": flag_item},
        },
        "required": ["summary", "actions_required", "departments_to_notify", "cross_document_flags"],
    }
    model = genai.GenerativeModel('gemini-1.5-flash-latest')
    prompt = f"Analyze this document and extract key info as JSON: {text_to_analyze}"
    config = genai.types.GenerationConfig(response_mime_type="application/json", response_schema=json_schema)
    response = model.generate_content(prompt, generation_config=config)
    return json.loads(response.text)
def check_relevance_with_gemini(summary_text):
    """Return True when Gemini judges the summary relevant to transport/infra.

    Relevance is decided by a simple substring check for "yes" in the
    model's lower-cased reply.
    """
    prompt = f'Is this summary relevant to transportation, infrastructure, railways, or metro systems? Answer "Yes" or "No". Summary: {summary_text}'
    gemini = genai.GenerativeModel('gemini-1.5-flash-latest')
    reply = gemini.generate_content(prompt).text
    return "yes" in reply.strip().lower()
# === API INPUT SCHEMA ===
class InputFile(BaseModel):
    """Request body for POST /predict: a path to a file on the server."""
    # Path to a .pdf, .txt, .png, .jpg, or .jpeg file readable by the server.
    file_path: str
@app.post("/predict")
def predict(file: InputFile):
    """Full document pipeline for one uploaded path.

    Phase 1 extracts text (PDF / TXT / image), phase 2 translates
    Malayalam lines to English, phase 3 asks Gemini for a structured
    analysis and gates the result on relevance to transport/metro topics.

    Returns the structured analysis dict on success, or a dict with an
    "error" / "status" key describing the failure.
    """
    if not GEMINI_API_KEY:
        return {"error": "Gemini API key not set."}
    path = file.file_path
    # Robustness: a missing path previously propagated an unhandled
    # exception (HTTP 500); fail fast with a JSON error instead.
    if not os.path.isfile(path):
        return {"error": f"File not found: {path}"}
    # Phase 1: Extract text (dispatch on extension).
    extractors = {
        ".pdf": extract_text_from_pdf,
        ".txt": read_text_from_txt,
        ".png": extract_text_from_image,
        ".jpg": extract_text_from_image,
        ".jpeg": extract_text_from_image,
    }
    extractor = extractors.get(os.path.splitext(path)[1].lower())
    if extractor is None:
        return {"error": "Unsupported file type."}
    original_text = extractor(path)
    # Phase 2: Translate Malayalam lines; other lines pass through.
    # NOTE: blank lines are dropped from the output (original behavior).
    translated_lines = []
    for line in original_text.split("\n"):
        if not line.strip():
            continue
        if detect_language(line) == LANGUAGE_TO_TRANSLATE:
            translated_lines.append(translate_chunk(line))
        else:
            translated_lines.append(line)
    final_text = "\n".join(translated_lines)
    # Phase 3: Gemini structured analysis + relevance gate.
    summary_data = generate_structured_json(final_text)
    if not summary_data or "summary" not in summary_data:
        return {"error": "Failed to generate analysis."}
    if check_relevance_with_gemini(summary_data["summary"]):
        return summary_data
    return {"status": "Not Applicable", "reason": "Document not relevant to KMRL."}