File size: 6,765 Bytes
307ef74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# inference_api.py
import os
import fitz  # PyMuPDF
import fasttext
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from IndicTransToolkit.processor import IndicProcessor
import google.generativeai as genai
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional
import json

# FastAPI application; the /predict route is registered on it further down.
app = FastAPI()

# === CONFIGURATION ===
# Read once at import time. When unset, Gemini is never configured and
# predict() answers every request with an error instead of running.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
# Hugging Face checkpoint used for the Indic -> English translation step.
TRANSLATION_MODEL_REPO_ID = "ai4bharat/indictrans2-indic-en-1B"
# Hugging Face checkpoint used for OCR on images Gemini classifies as documents.
OCR_MODEL_ID = "microsoft/trocr-base-printed"
# Lines whose detected language matches this code are sent to the translator.
# NOTE(review): detect_language() returns the fastText label with only the
# "__label__" prefix removed (for this LID model that is e.g. "mal_Mlym", per
# its model card), so comparisons against this bare code must ignore the
# script suffix.
LANGUAGE_TO_TRANSLATE = "mal"
# Device for the translation model only; the OCR pipeline is pinned to CPU
# separately via device=-1 where it is created.
DEVICE = "cpu"

# --- Configure Gemini ---
# Configure the client only when a key exists; otherwise the service still
# starts, and predict() reports the missing key on each request.
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    print("🔴 GEMINI_API_KEY not set.")

# --- Load Models ---
# Everything below runs at import time, so starting the service downloads
# (or loads from the local Hugging Face cache) every model before the first
# request can be served.
# NOTE(review): trust_remote_code=True executes Python code shipped in the
# model repository; keep TRANSLATION_MODEL_REPO_ID pinned to a trusted source.
translation_tokenizer = AutoTokenizer.from_pretrained(TRANSLATION_MODEL_REPO_ID, trust_remote_code=True)
translation_model = AutoModelForSeq2SeqLM.from_pretrained(
    TRANSLATION_MODEL_REPO_ID, trust_remote_code=True, torch_dtype=torch.float32
).to(DEVICE)
# IndicTrans2 pre/post-processing helper used around generation in
# translate_chunk(); inference=True presumably selects its inference-time
# behaviour (see IndicTransToolkit docs).
ip = IndicProcessor(inference=True)

# fastText language-identification model used by detect_language();
# hf_hub_download returns the local filesystem path of model.bin.
ft_model_path = hf_hub_download(repo_id="facebook/fasttext-language-identification", filename="model.bin")
lang_detect_model = fasttext.load_model(ft_model_path)

# TrOCR image-to-text pipeline; device=-1 selects CPU in the transformers
# pipeline API, independently of DEVICE above.
ocr_pipeline = pipeline("image-to-text", model=OCR_MODEL_ID, device=-1)

# === HELPER FUNCTIONS ===
def classify_image_with_gemini(image: Image.Image):
    """Ask Gemini whether *image* is a text document or a technical diagram.

    Returns "diagram" when the model's reply contains the word "diagram"
    anywhere (case-insensitive); every other reply maps to "document".
    """
    gemini = genai.GenerativeModel('gemini-1.5-flash-latest')
    question = "Is this image primarily a text document or an engineering/technical diagram? Answer with only 'document' or 'diagram'."
    reply = gemini.generate_content([question, image])
    verdict = reply.text.strip().lower()
    # Substring match tolerates extra words or punctuation around the answer.
    if "diagram" in verdict:
        return "diagram"
    return "document"

def summarize_diagram_with_gemini(image: Image.Image):
    """Return Gemini's concise textual description of a technical diagram image."""
    gemini = genai.GenerativeModel('gemini-1.5-flash-latest')
    instruction = "Describe the contents of this technical diagram in a concise summary."
    reply = gemini.generate_content([instruction, image])
    description = reply.text
    return description.strip()

def extract_text_from_image(path):
    """Return text for the image stored at *path*.

    Gemini first classifies the image. Images classified as diagrams are
    summarised by Gemini. Everything else goes through the TrOCR
    ``ocr_pipeline``, and the first generated text is returned ("" if the
    pipeline yields nothing).

    NOTE(review): TrOCR checkpoints are documented for single text-line
    images, so a full-page scan probably yields only a partial
    transcription. Confirm this is acceptable or add line segmentation.
    """
    # Pillow opens files lazily and keeps the handle open. The context
    # manager closes it deterministically. convert() loads the pixels and
    # returns a separate in-memory image that stays valid after close.
    with Image.open(path) as source:
        image = source.convert("RGB")
    if classify_image_with_gemini(image) == "diagram":
        return summarize_diagram_with_gemini(image)
    out = ocr_pipeline(image)
    return out[0]["generated_text"] if out else ""

def extract_text_from_pdf(path):
    """Return the plain text of every page of the PDF at *path*.

    Each page's text is followed by a single newline. get_text("text") reads
    the PDF's text layer, so scanned (image-only) pages presumably contribute
    little or no text because no OCR is applied here.
    """
    # The original never closed the Document, leaking the file handle until
    # garbage collection. The context manager closes it even if extraction
    # raises. The join fully consumes the generator before the block exits.
    with fitz.open(path) as doc:
        return "".join(page.get_text("text") + "\n" for page in doc)

def read_text_from_txt(path):
    """Return the entire contents of the UTF-8 text file at *path*.

    A file that is not valid UTF-8 raises UnicodeDecodeError.
    """
    with open(path, mode="r", encoding="utf-8") as handle:
        contents = handle.read()
    return contents

def detect_language(text_snippet):
    """Return the top fastText language label for *text_snippet*, or None.

    Newlines are flattened to spaces because fastText predicts one line at a
    time. The "__label__" prefix is stripped from the label (for the
    facebook LID model this leaves a code such as "mal_Mlym"). Blank input,
    or an empty prediction, yields None.
    """
    cleaned = text_snippet.replace("\n", " ").strip()
    if not cleaned:
        return None
    prediction = lang_detect_model.predict(cleaned, k=1)
    if not prediction or not prediction[0]:
        return None
    top_label = prediction[0][0]
    return top_label.split("__")[-1]

def translate_chunk(chunk):
    """Translate one Malayalam string into English with IndicTrans2.

    The text is prepared by IndicProcessor, tokenized (truncated to 512
    tokens), generated with 5-beam search under ``torch.no_grad()``, then
    decoded and post-processed. Returns the single translated string.
    """
    source_tag, target_tag = "mal_Mlym", "eng_Latn"
    prepared = ip.preprocess_batch([chunk], src_lang=source_tag, tgt_lang=target_tag)
    encoded = translation_tokenizer(
        prepared,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(DEVICE)
    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        output_ids = translation_model.generate(
            **encoded,
            num_beams=5,
            max_length=512,
            early_stopping=True,
        )
    raw_texts = translation_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    translations = ip.postprocess_batch(raw_texts, lang=target_tag)
    return translations[0]

def generate_structured_json(text_to_analyze):
    """Ask Gemini for a schema-constrained JSON analysis of *text_to_analyze*.

    The reply is constrained to an object with four required keys:
    ``summary`` (string), ``actions_required`` (list of action objects with
    action/priority/deadline/notes), ``departments_to_notify`` (list of
    strings) and ``cross_document_flags`` (list of related-document objects).
    Returns the parsed JSON value. Errors from the Gemini call, or from
    ``json.loads`` on malformed output, propagate to the caller.

    NOTE(review): the document text is interpolated directly into the prompt,
    so document content can steer the model (prompt injection).
    """
    gemini = genai.GenerativeModel('gemini-1.5-flash-latest')
    request = f"Analyze this document and extract key info as JSON: {text_to_analyze}"

    action_item = {
        "type": "OBJECT",
        "properties": {
            "action": {"type": "STRING"},
            "priority": {"type": "STRING", "enum": ["High", "Medium", "Low"]},
            "deadline": {"type": "STRING"},
            "notes": {"type": "STRING"},
        },
        "required": ["action", "priority", "deadline", "notes"],
    }
    related_document = {
        "type": "OBJECT",
        "properties": {
            "related_document_type": {"type": "STRING"},
            "related_issue": {"type": "STRING"},
        },
        "required": ["related_document_type", "related_issue"],
    }
    schema = {
        "type": "OBJECT",
        "properties": {
            "summary": {"type": "STRING"},
            "actions_required": {"type": "ARRAY", "items": action_item},
            "departments_to_notify": {"type": "ARRAY", "items": {"type": "STRING"}},
            "cross_document_flags": {"type": "ARRAY", "items": related_document},
        },
        "required": ["summary", "actions_required", "departments_to_notify", "cross_document_flags"],
    }

    config = genai.types.GenerationConfig(
        response_mime_type="application/json",
        response_schema=schema,
    )
    reply = gemini.generate_content(request, generation_config=config)
    return json.loads(reply.text)

def check_relevance_with_gemini(summary_text):
    """Ask Gemini whether *summary_text* concerns transport or infrastructure.

    Returns True when the model's reply contains "yes" (case-insensitive
    substring match), otherwise False.
    """
    gemini = genai.GenerativeModel('gemini-1.5-flash-latest')
    question = (
        'Is this summary relevant to transportation, infrastructure, railways, '
        'or metro systems? Answer "Yes" or "No". '
        f'Summary: {summary_text}'
    )
    reply = gemini.generate_content(question)
    answer = reply.text.strip().lower()
    return "yes" in answer

# === API INPUT SCHEMA ===
class InputFile(BaseModel):
    """Request body for POST /predict."""

    # Path to a file on the server's own filesystem (the file is not
    # uploaded); predict() chooses the extractor from its extension.
    file_path: str

def _select_extractor(ext):
    """Return the text-extraction function for a lower-cased file extension, or None."""
    if ext == ".pdf":
        return extract_text_from_pdf
    if ext == ".txt":
        return read_text_from_txt
    if ext in (".png", ".jpg", ".jpeg"):
        return extract_text_from_image
    return None


def _is_target_language(lang):
    """Return True if a detected language label denotes LANGUAGE_TO_TRANSLATE.

    detect_language() returns fastText labels that carry a script suffix
    (e.g. "mal_Mlym"), while LANGUAGE_TO_TRANSLATE is the bare code "mal".
    A plain equality check therefore never matched, so the comparison also
    accepts a match on the part before the first underscore.
    """
    if not lang:
        return False
    return lang == LANGUAGE_TO_TRANSLATE or lang.split("_", 1)[0] == LANGUAGE_TO_TRANSLATE


def _translate_lines(text):
    """Translate target-language lines of *text* to English, line by line.

    Blank lines are dropped; non-target lines are kept verbatim.
    """
    translated_lines = []
    for ln in text.split("\n"):
        if not ln.strip():
            continue
        if _is_target_language(detect_language(ln)):
            translated_lines.append(translate_chunk(ln))
        else:
            translated_lines.append(ln)
    return "\n".join(translated_lines)


@app.post("/predict")
def predict(file: InputFile):
    """Extract, translate and analyse the document at ``file.file_path``.

    Pipeline:
      1. Extract text according to the extension (.pdf, .txt, .png/.jpg/.jpeg).
      2. Translate lines detected as LANGUAGE_TO_TRANSLATE into English.
      3. Ask Gemini for a structured JSON analysis, then for a relevance check.

    Returns the analysis dict when the document is relevant, a
    ``{"status": "Not Applicable", ...}`` dict when it is not, or an
    ``{"error": ...}`` dict when the API key is missing, the file type is
    unsupported, the file does not exist, or the analysis fails.
    """
    if not GEMINI_API_KEY:
        return {"error": "Gemini API key not set."}
    # NOTE(review): file_path comes straight from the request body and is read
    # from the server's filesystem, so any file this process can access is
    # readable by a client. Restrict it to an upload directory before exposing
    # this endpoint.
    path = file.file_path
    ext = os.path.splitext(path)[1].lower()

    # Phase 1: Extract text
    extractor = _select_extractor(ext)
    if extractor is None:
        return {"error": "Unsupported file type."}
    if not os.path.isfile(path):
        # Previously surfaced as an unhandled FileNotFoundError (HTTP 500).
        return {"error": "File not found."}
    original_text = extractor(path)

    # Phase 2: Translate Malayalam if detected
    final_text = _translate_lines(original_text)

    # Phase 3: Gemini analysis
    try:
        summary_data = generate_structured_json(final_text)
    except ValueError:
        # json.JSONDecodeError (malformed model output) subclasses ValueError;
        # report it through the same error the guard below produces.
        return {"error": "Failed to generate analysis."}
    if not isinstance(summary_data, dict) or "summary" not in summary_data:
        return {"error": "Failed to generate analysis."}

    if check_relevance_with_gemini(summary_data["summary"]):
        return summary_data
    return {"status": "Not Applicable", "reason": "Document not relevant to KMRL."}