Spaces:
Runtime error
Runtime error
Commit
·
a66fd49
1
Parent(s):
2b903c0
updated app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
# ------------------------------
|
| 2 |
-
# AI Multi‑Modal Assistant —
|
| 3 |
# ------------------------------
|
| 4 |
|
| 5 |
import gradio as gr
|
|
@@ -11,8 +11,9 @@ import pandas as pd
|
|
| 11 |
from reportlab.lib.pagesizes import letter
|
| 12 |
from reportlab.pdfgen import canvas
|
| 13 |
import io
|
| 14 |
-
import yake
|
| 15 |
import tempfile
|
|
|
|
| 16 |
|
| 17 |
# ------------------------------
|
| 18 |
# 1. Load Models & Labels
|
|
@@ -23,9 +24,6 @@ sentiment_model = pipeline(
|
|
| 23 |
"sentiment-analysis",
|
| 24 |
model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
|
| 25 |
)
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
# Summarization Model
|
| 29 |
summarizer_model = pipeline("summarization", model="facebook/bart-large-cnn")
|
| 30 |
|
| 31 |
# Image classification model
|
|
@@ -40,9 +38,8 @@ preprocess = transforms.Compose(
|
|
| 40 |
]
|
| 41 |
)
|
| 42 |
|
| 43 |
-
# Load ImageNet class labels
|
| 44 |
-
|
| 45 |
-
with open("imagenet_classes.txt", "r") as f: # ensure this file is in your folder
|
| 46 |
imagenet_labels = [s.strip() for s in f.readlines()]
|
| 47 |
|
| 48 |
# Keyword extraction
|
|
@@ -52,7 +49,6 @@ kw_extractor = yake.KeywordExtractor(lan="en", top=5)
|
|
| 52 |
# 2. Helper Functions
|
| 53 |
# ------------------------------
|
| 54 |
|
| 55 |
-
|
| 56 |
def analyze_text(text: str) -> dict:
|
| 57 |
sentiment = sentiment_model(text)[0]
|
| 58 |
summary = summarizer_model(
|
|
@@ -66,18 +62,18 @@ def analyze_text(text: str) -> dict:
|
|
| 66 |
"Keywords": keywords,
|
| 67 |
}
|
| 68 |
|
| 69 |
-
|
| 70 |
def analyze_image(image: Image.Image) -> dict:
|
| 71 |
img_t = preprocess(image).unsqueeze(0)
|
| 72 |
with torch.no_grad():
|
| 73 |
outputs = image_model(img_t)
|
| 74 |
class_idx = outputs.argmax().item()
|
| 75 |
-
if 0 <= class_idx < len(imagenet_labels)
|
| 76 |
-
class_label = imagenet_labels[class_idx]
|
| 77 |
-
else:
|
| 78 |
-
class_label = f"Class index {class_idx}"
|
| 79 |
return {"Predicted Class Index": class_idx, "Predicted Class Label": class_label}
|
| 80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
def generate_pdf(results: dict) -> str:
|
| 83 |
buffer = io.BytesIO()
|
|
@@ -96,19 +92,16 @@ def generate_pdf(results: dict) -> str:
|
|
| 96 |
c.save()
|
| 97 |
buffer.seek(0)
|
| 98 |
|
| 99 |
-
# ✅ Save to a temp file and return path instead of buffer
|
| 100 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
|
| 101 |
tmp.write(buffer.getvalue())
|
| 102 |
tmp_path = tmp.name
|
| 103 |
|
| 104 |
-
return tmp_path
|
| 105 |
-
|
| 106 |
|
| 107 |
# ------------------------------
|
| 108 |
# 3. Multi‑Modal Analysis Function
|
| 109 |
# ------------------------------
|
| 110 |
def analyze(input_data):
|
| 111 |
-
# Handles both text and image correctly in Gradio
|
| 112 |
if isinstance(input_data, str) and input_data.strip():
|
| 113 |
return analyze_text(input_data)
|
| 114 |
elif isinstance(input_data, dict) and "image" in input_data:
|
|
@@ -118,7 +111,6 @@ def analyze(input_data):
|
|
| 118 |
else:
|
| 119 |
return {"Error": "Please enter text or upload an image."}
|
| 120 |
|
| 121 |
-
|
| 122 |
# ------------------------------
|
| 123 |
# 4. Gradio UI Layout
|
| 124 |
# ------------------------------
|
|
@@ -126,47 +118,54 @@ def analyze(input_data):
|
|
| 126 |
with gr.Blocks() as demo:
|
| 127 |
gr.Markdown("## AI Multi‑Modal Assistant")
|
| 128 |
|
|
|
|
| 129 |
with gr.Tab("Image Analysis"):
|
| 130 |
image_input = gr.Image(type="pil", label="Upload an image for classification")
|
| 131 |
-
analyze_image_button = gr.Button("Analyze Image")
|
| 132 |
-
|
| 133 |
image_output = gr.JSON(label="Image Analysis Results")
|
| 134 |
pdf_button_image = gr.Button("Download Report (PDF)")
|
| 135 |
|
| 136 |
analyze_image_button.click(fn=analyze, inputs=image_input, outputs=image_output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
-
|
| 139 |
-
fn=lambda x: generate_pdf(analyze(x)),
|
| 140 |
-
inputs=image_input,
|
| 141 |
-
outputs=gr.File(label="Download PDF Report"),
|
| 142 |
-
)
|
| 143 |
-
|
| 144 |
-
|
| 145 |
with gr.Tab("Text Analysis"):
|
| 146 |
text_input = gr.Textbox(
|
| 147 |
label="Enter text to analyze",
|
| 148 |
placeholder="Type your text here...",
|
| 149 |
lines=5
|
| 150 |
)
|
| 151 |
-
|
| 152 |
analyze_text_button = gr.Button("Analyze Text")
|
| 153 |
text_output = gr.JSON(label="Text Analysis Results")
|
| 154 |
pdf_button_text = gr.Button("Download Report (PDF)")
|
| 155 |
|
| 156 |
-
# Text analysis events
|
| 157 |
-
|
| 158 |
analyze_text_button.click(
|
| 159 |
fn=analyze,
|
| 160 |
inputs=text_input,
|
| 161 |
outputs=text_output
|
| 162 |
)
|
| 163 |
-
|
| 164 |
pdf_button_text.click(
|
| 165 |
fn=lambda x: generate_pdf(analyze(x)),
|
| 166 |
inputs=text_input,
|
| 167 |
outputs=gr.File(label="Download PDF Report")
|
| 168 |
)
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
# ------------------------------
|
| 172 |
# 5. Launch the App
|
|
|
|
| 1 |
# ------------------------------
|
| 2 |
+
# AI Multi‑Modal Assistant — Phase 2 (OCR Added)
|
| 3 |
# ------------------------------
|
| 4 |
|
| 5 |
import gradio as gr
|
|
|
|
| 11 |
from reportlab.lib.pagesizes import letter
|
| 12 |
from reportlab.pdfgen import canvas
|
| 13 |
import io
|
| 14 |
+
import yake
|
| 15 |
import tempfile
|
| 16 |
+
import pytesseract # <-- OCR
|
| 17 |
|
| 18 |
# ------------------------------
|
| 19 |
# 1. Load Models & Labels
|
|
|
|
| 24 |
"sentiment-analysis",
|
| 25 |
model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
|
| 26 |
)
|
|
|
|
|
|
|
|
|
|
| 27 |
summarizer_model = pipeline("summarization", model="facebook/bart-large-cnn")
|
| 28 |
|
| 29 |
# Image classification model
|
|
|
|
| 38 |
]
|
| 39 |
)
|
| 40 |
|
| 41 |
+
# Load ImageNet class labels
|
| 42 |
+
with open("imagenet_classes.txt", "r") as f:
|
|
|
|
| 43 |
imagenet_labels = [s.strip() for s in f.readlines()]
|
| 44 |
|
| 45 |
# Keyword extraction
|
|
|
|
| 49 |
# 2. Helper Functions
|
| 50 |
# ------------------------------
|
| 51 |
|
|
|
|
| 52 |
def analyze_text(text: str) -> dict:
|
| 53 |
sentiment = sentiment_model(text)[0]
|
| 54 |
summary = summarizer_model(
|
|
|
|
| 62 |
"Keywords": keywords,
|
| 63 |
}
|
| 64 |
|
|
|
|
| 65 |
def analyze_image(image: Image.Image) -> dict:
|
| 66 |
img_t = preprocess(image).unsqueeze(0)
|
| 67 |
with torch.no_grad():
|
| 68 |
outputs = image_model(img_t)
|
| 69 |
class_idx = outputs.argmax().item()
|
| 70 |
+
class_label = imagenet_labels[class_idx] if 0 <= class_idx < len(imagenet_labels) else f"Class index {class_idx}"
|
|
|
|
|
|
|
|
|
|
| 71 |
return {"Predicted Class Index": class_idx, "Predicted Class Label": class_label}
|
| 72 |
|
| 73 |
+
def ocr_image(image: Image.Image) -> dict:
|
| 74 |
+
"""Extract text from uploaded image using Tesseract OCR."""
|
| 75 |
+
text = pytesseract.image_to_string(image)
|
| 76 |
+
return {"Extracted Text": text}
|
| 77 |
|
| 78 |
def generate_pdf(results: dict) -> str:
|
| 79 |
buffer = io.BytesIO()
|
|
|
|
| 92 |
c.save()
|
| 93 |
buffer.seek(0)
|
| 94 |
|
|
|
|
| 95 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
|
| 96 |
tmp.write(buffer.getvalue())
|
| 97 |
tmp_path = tmp.name
|
| 98 |
|
| 99 |
+
return tmp_path
|
|
|
|
| 100 |
|
| 101 |
# ------------------------------
|
| 102 |
# 3. Multi‑Modal Analysis Function
|
| 103 |
# ------------------------------
|
| 104 |
def analyze(input_data):
|
|
|
|
| 105 |
if isinstance(input_data, str) and input_data.strip():
|
| 106 |
return analyze_text(input_data)
|
| 107 |
elif isinstance(input_data, dict) and "image" in input_data:
|
|
|
|
| 111 |
else:
|
| 112 |
return {"Error": "Please enter text or upload an image."}
|
| 113 |
|
|
|
|
| 114 |
# ------------------------------
|
| 115 |
# 4. Gradio UI Layout
|
| 116 |
# ------------------------------
|
|
|
|
| 118 |
with gr.Blocks() as demo:
|
| 119 |
gr.Markdown("## AI Multi‑Modal Assistant")
|
| 120 |
|
| 121 |
+
# ------------------ Image Analysis Tab ------------------
|
| 122 |
with gr.Tab("Image Analysis"):
|
| 123 |
image_input = gr.Image(type="pil", label="Upload an image for classification")
|
| 124 |
+
analyze_image_button = gr.Button("Analyze Image")
|
|
|
|
| 125 |
image_output = gr.JSON(label="Image Analysis Results")
|
| 126 |
pdf_button_image = gr.Button("Download Report (PDF)")
|
| 127 |
|
| 128 |
analyze_image_button.click(fn=analyze, inputs=image_input, outputs=image_output)
|
| 129 |
+
pdf_button_image.click(
|
| 130 |
+
fn=lambda x: generate_pdf(analyze(x)),
|
| 131 |
+
inputs=image_input,
|
| 132 |
+
outputs=gr.File(label="Download PDF Report"),
|
| 133 |
+
)
|
| 134 |
|
| 135 |
+
# ------------------ Text Analysis Tab ------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
with gr.Tab("Text Analysis"):
|
| 137 |
text_input = gr.Textbox(
|
| 138 |
label="Enter text to analyze",
|
| 139 |
placeholder="Type your text here...",
|
| 140 |
lines=5
|
| 141 |
)
|
|
|
|
| 142 |
analyze_text_button = gr.Button("Analyze Text")
|
| 143 |
text_output = gr.JSON(label="Text Analysis Results")
|
| 144 |
pdf_button_text = gr.Button("Download Report (PDF)")
|
| 145 |
|
|
|
|
|
|
|
| 146 |
analyze_text_button.click(
|
| 147 |
fn=analyze,
|
| 148 |
inputs=text_input,
|
| 149 |
outputs=text_output
|
| 150 |
)
|
|
|
|
| 151 |
pdf_button_text.click(
|
| 152 |
fn=lambda x: generate_pdf(analyze(x)),
|
| 153 |
inputs=text_input,
|
| 154 |
outputs=gr.File(label="Download PDF Report")
|
| 155 |
)
|
| 156 |
|
| 157 |
+
# ------------------ OCR Tab ------------------
|
| 158 |
+
with gr.Tab("OCR"):
|
| 159 |
+
ocr_input = gr.Image(type="pil", label="Upload image for OCR")
|
| 160 |
+
ocr_output = gr.JSON(label="OCR Results")
|
| 161 |
+
pdf_button_ocr = gr.Button("Download OCR PDF")
|
| 162 |
+
|
| 163 |
+
ocr_input.submit(fn=ocr_image, inputs=ocr_input, outputs=ocr_output)
|
| 164 |
+
pdf_button_ocr.click(
|
| 165 |
+
fn=lambda x: generate_pdf(x),
|
| 166 |
+
inputs=ocr_output,
|
| 167 |
+
outputs=gr.File(label="Download PDF Report"),
|
| 168 |
+
)
|
| 169 |
|
| 170 |
# ------------------------------
|
| 171 |
# 5. Launch the App
|