# Author: tayyab-077 (last updated in commit 38e4093)
# ------------------------------
# AI Multi-Modal Assistant — Enhanced Version (Phase 2: OCR added)
# ------------------------------
import gradio as gr
from transformers import pipeline, pipeline as hf_pipeline, Wav2Vec2Processor, TFWav2Vec2ForCTC
from PIL import Image
import torch
from torchvision import models, transforms
import pandas as pd
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
import io
from textwrap import wrap
import yake
import tempfile
import yake # keyword extraction
import tempfile
from fpdf import FPDF
import time
import tensorflow as tf
import soundfile as sf
from langdetect import detect, DetectorFactory
DetectorFactory.seed = 0 # Consistent language detection
import os
import subprocess
# Ensure the tesseract binary (driven by pytesseract) is available at runtime.
# Best-effort only: apt-get typically needs root privileges and a Debian-based
# image. Failures are reported but do not stop the app; OCR calls fail later.
def _tesseract_available():
    """Return True if `tesseract --version` runs and exits successfully."""
    try:
        subprocess.run(["tesseract", "--version"], check=True, stdout=subprocess.PIPE)
    except (subprocess.CalledProcessError, FileNotFoundError):
        return False
    return True


if not _tesseract_available():
    print("Installing tesseract runtime...")
    try:
        # Argument lists instead of a shell string. The install step only runs
        # if the package-index update succeeded (same as `update && install`).
        if subprocess.run(["apt-get", "update", "-y"]).returncode == 0:
            subprocess.run(["apt-get", "install", "-y", "tesseract-ocr"])
    except OSError as err:  # e.g. apt-get is not present on this system
        print(f"Could not run apt-get: {err}")
    if not _tesseract_available():
        print("WARNING: tesseract is still unavailable; OCR requests will fail.")

import pytesseract  # <-- OCR
# ------------------------------
# 1. Load Models & Labels
# ------------------------------
# NLP pipelines (Hugging Face `pipeline` factory, model ids pinned explicitly).
# Sentiment classifier: checkpoint
# distilbert/distilbert-base-uncased-finetuned-sst-2-english.
sentiment_model = pipeline(
    task="sentiment-analysis",
    model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
)
# Summarizer: checkpoint facebook/bart-large-cnn.
summarizer_model = pipeline(
    task="summarization",
    model="facebook/bart-large-cnn",
)
# Image classification model: ResNet-50 with pretrained weights, switched to
# inference mode (BatchNorm layers use their running statistics).
image_model = models.resnet50(pretrained=True)
image_model.eval()

# Preprocessing: resize the shorter side to 256 px, centre-crop 224x224,
# convert to a float tensor, then normalise each RGB channel.
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
])
# Load ImageNet class labels: line i of the file is the label for class index i
# (analyze_image indexes this list with the predicted class). The path is
# relative to the working directory; open() raises FileNotFoundError if absent.
with open("imagenet_classes.txt", "r") as labels_file:
    imagenet_labels = list(map(str.strip, labels_file))
# Keyword extraction: YAKE configured for English, returning the top 5 keywords.
kw_extractor = yake.KeywordExtractor(top=5, lan="en")
# ------------------------------
# 2. Helper Functions
# ------------------------------
def analyze_text(text: str) -> dict:
    """Run sentiment analysis, summarisation and keyword extraction on *text*.

    Args:
        text: Input text (``analyze`` only forwards non-blank strings).

    Returns:
        Dict with "Sentiment" (model label), "Sentiment Score" (rounded to
        3 decimals), "Summary" (generated text) and "Keywords" (up to 5 YAKE
        keyword strings).
    """
    # truncation=True: both models have a maximum input length; without it,
    # long inputs make the pipelines raise instead of analysing the prefix.
    sentiment = sentiment_model(text, truncation=True)[0]
    # Cap the summary length by the input word count (a rough proxy for
    # tokens) so short inputs are not padded out; never more than 50.
    summary = summarizer_model(
        text,
        max_length=min(len(text.split()) + 10, 50),
        min_length=5,
        truncation=True,
    )[0]["summary_text"]
    keywords = [keyword for keyword, _score in kw_extractor.extract_keywords(text)]
    return {
        "Sentiment": sentiment["label"],
        "Sentiment Score": round(sentiment["score"], 3),
        "Summary": summary,
        "Keywords": keywords,
    }
def analyze_image(image: Image.Image) -> dict:
    """Classify *image* with the module-level ResNet-50.

    Args:
        image: PIL image in any mode; it is converted to RGB first.

    Returns:
        Dict with "Predicted Class Index" (int argmax of the model output) and
        "Predicted Class Label" (from ``imagenet_labels``, or a placeholder
        string when the index is outside the loaded label list).
    """
    # preprocess normalises exactly 3 channels; RGBA, grayscale ("L") or
    # palette ("P") uploads would otherwise fail in Normalize.
    if image.mode != "RGB":
        image = image.convert("RGB")
    batch = preprocess(image).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        outputs = image_model(batch)
    class_idx = outputs.argmax().item()
    if 0 <= class_idx < len(imagenet_labels):
        class_label = imagenet_labels[class_idx]
    else:
        class_label = f"Class index {class_idx}"
    return {"Predicted Class Index": class_idx, "Predicted Class Label": class_label}
def ocr_image(image: Image.Image) -> dict:
    """Extract text from an uploaded image using Tesseract OCR.

    Args:
        image: PIL image from the OCR tab's ``gr.Image(type="pil")`` input;
            ``None`` when nothing has been uploaded.

    Returns:
        ``{"Extracted Text": <text>}`` on success, or an ``{"Error": ...}``
        dict (the same shape ``analyze`` uses) when no image was provided.
    """
    if image is None:
        return {"Error": "Please upload an image."}
    text = pytesseract.image_to_string(image)
    return {"Extracted Text": text}
def generate_pdf(results: dict) -> str:
    """Render *results* as a plain-text PDF report using ReportLab.

    NOTE(review): a second ``generate_pdf`` (FPDF-based) is defined later in
    this module and rebinds the name, so the Gradio callbacks, which look the
    name up when they fire, call that later version rather than this one.

    Each entry is written as "key: value". List/dict values are stringified,
    lines are wrapped at 90 characters, and a new page starts once the cursor
    falls below 60 pt from the bottom edge.

    Returns:
        Path of a temporary ``.pdf`` file created with ``delete=False``; the
        caller owns its cleanup.
    """
    _, page_height = letter
    left_x, line_step, entry_gap, bottom_margin = 50, 18, 10, 60

    pdf_bytes = io.BytesIO()
    pdf_canvas = canvas.Canvas(pdf_bytes, pagesize=letter)
    pdf_canvas.setFont("Helvetica", 12)
    pdf_canvas.drawString(left_x, page_height - 50, "AI Multi-Modal Assistant Report")

    cursor_y = page_height - 80
    for key, value in results.items():
        shown = str(value) if isinstance(value, (list, dict)) else value
        for line in wrap(f"{key}: {shown}", width=90):
            pdf_canvas.drawString(left_x, cursor_y, line)
            cursor_y -= line_step
            if cursor_y < bottom_margin:
                pdf_canvas.showPage()
                pdf_canvas.setFont("Helvetica", 12)
                cursor_y = page_height - 50
        cursor_y -= entry_gap
    pdf_canvas.save()

    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
        tmp.write(pdf_bytes.getvalue())
        return tmp.name
# ------------------------------
# 3. Multi‑Modal Analysis Function
# ------------------------------
def analyze(input_data):
    """Route a Gradio input value to the matching analyser.

    - non-blank ``str``                 -> ``analyze_text``
    - ``dict`` containing ``"image"``   -> ``analyze_image`` on that entry
    - ``PIL.Image.Image``               -> ``analyze_image``

    Anything else (``None``, blank text, other dicts) yields an error dict.
    """
    if isinstance(input_data, str):
        if input_data.strip():
            return analyze_text(input_data)
    elif isinstance(input_data, dict):
        if "image" in input_data:
            return analyze_image(input_data["image"])
    elif isinstance(input_data, Image.Image):
        return analyze_image(input_data)
    return {"Error": "Please enter text or upload an image."}
# ------------------------------
# 4. Gradio UI Layout
# ------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## AI Multi‑Modal Assistant")

    # ------------------ Image Analysis Tab ------------------
    with gr.Tab("Image Analysis"):
        image_input = gr.Image(type="pil", label="Upload an image for classification")
        analyze_image_button = gr.Button("Analyze Image")
        image_output = gr.JSON(label="Image Analysis Results")
        pdf_button_image = gr.Button("Download Report (PDF)")

        analyze_image_button.click(fn=analyze, inputs=image_input, outputs=image_output)
        # The PDF re-runs the analysis on the current input, so it works even
        # if "Analyze Image" was never clicked. `generate_pdf` is resolved when
        # the button fires, i.e. whatever that module-level name is bound to then.
        pdf_button_image.click(
            fn=lambda x: generate_pdf(analyze(x)),
            inputs=image_input,
            outputs=gr.File(label="Download PDF Report"),
        )

    # ------------------ Text Analysis Tab ------------------
    with gr.Tab("Text Analysis"):
        text_input = gr.Textbox(
            label="Enter text to analyze",
            placeholder="Type your text here...",
            lines=5,
        )
        analyze_text_button = gr.Button("Analyze Text")
        text_output = gr.JSON(label="Text Analysis Results")
        pdf_button_text = gr.Button("Download Report (PDF)")

        analyze_text_button.click(fn=analyze, inputs=text_input, outputs=text_output)
        pdf_button_text.click(
            fn=lambda x: generate_pdf(analyze(x)),
            inputs=text_input,
            outputs=gr.File(label="Download PDF Report"),
        )

    # ------------------ OCR Tab ------------------
    with gr.Tab("OCR"):
        ocr_input = gr.Image(type="pil", label="Upload image for OCR")
        ocr_button = gr.Button("Extract Text (OCR)")
        ocr_output = gr.JSON(label="OCR Results")
        pdf_button_ocr = gr.Button("Download OCR PDF")

        ocr_button.click(fn=ocr_image, inputs=ocr_input, outputs=ocr_output)
        # This PDF is built from the displayed OCR result, which is empty (None)
        # until OCR has run; substitute an error entry rather than handing None
        # to generate_pdf, which iterates over .items().
        pdf_button_ocr.click(
            fn=lambda x: generate_pdf(x or {"Error": "Please run OCR before downloading the PDF."}),
            inputs=ocr_output,
            outputs=gr.File(label="Download PDF Report"),
        )
# ------------------------------
# 5. Voice Input (Speech-to-Text + Language Detection + Translation)
#    plus a Unicode-capable PDF generator used by every tab
# ------------------------------
# Auto-download DejaVuSans.ttf, a TTF font covering many Unicode scripts
# ------------------------------
# Font file used by the FPDF report generator; downloaded once on first start.
FONT_PATH = "DejaVuSans.ttf"
if not os.path.exists(FONT_PATH):
    import requests

    url = "https://github.com/dejavu-fonts/dejavu-fonts/raw/master/ttf/DejaVuSans.ttf"
    try:
        r = requests.get(url, timeout=60)
        # Reject HTTP errors: otherwise an HTML error page would be saved as the
        # .ttf, and because the file then exists it would never be re-fetched.
        r.raise_for_status()
    except requests.RequestException as err:
        # Best-effort: the app still starts; only PDF generation will fail.
        print(f"WARNING: could not download {FONT_PATH}: {err}")
    else:
        # Write under a temporary name, then rename, so an interrupted write
        # never leaves a truncated font at FONT_PATH.
        partial_path = FONT_PATH + ".part"
        with open(partial_path, "wb") as f:
            f.write(r.content)
        os.replace(partial_path, FONT_PATH)
# ------------------------------
# PDF generator with full Unicode support
# ------------------------------
def generate_pdf(content_dict, output_file=None):
    """Render *content_dict* as a "key:" / value PDF report using FPDF.

    Text is drawn with the TTF at ``FONT_PATH`` so non-Latin characters
    (OCR output, transcripts) can be rendered.

    Args:
        content_dict: Mapping of section title -> value. Non-string values
            (lists, numbers, dicts, None) are converted with ``str()``.
            ``None`` or an empty mapping produces a blank page.
        output_file: Destination path. When omitted, a uniquely named ``.pdf``
            is created in the system temp directory, so concurrent requests
            never overwrite each other's reports.

    Returns:
        The path of the written PDF.
    """
    pdf = FPDF()
    pdf.add_page()
    pdf.add_font("DejaVu", "", FONT_PATH, uni=True)
    # Only the regular face is available; register it under the bold style too
    # so set_font("DejaVu", "B", ...) below does not raise "Undefined font".
    pdf.add_font("DejaVu", "B", FONT_PATH, uni=True)
    for key, value in (content_dict or {}).items():
        pdf.set_font("DejaVu", "B", 12)
        pdf.cell(0, 10, f"{key}:", ln=True)
        pdf.set_font("DejaVu", "", 12)
        # multi_cell needs text; analysis results contain lists, floats and ints.
        pdf.multi_cell(0, 10, value if isinstance(value, str) else str(value))
        pdf.ln(5)
    if not output_file:
        fd, output_file = tempfile.mkstemp(
            prefix=f"transcript_{int(time.time())}_", suffix=".pdf"
        )
        os.close(fd)  # FPDF reopens the path itself
    pdf.output(output_file)
    return output_file
# ------------------------------
# TF-based STT (Wav2Vec2 for multilingual)
# ------------------------------
# Both the processor (feature extractor + CTC tokenizer) and the TensorFlow
# CTC model are loaded from the same checkpoint id.
# NOTE(review): confirm this checkpoint ships a tokenizer vocabulary, a
# fine-tuned CTC head and TensorFlow weights; a purely self-supervised
# checkpoint would make these calls fail or yield meaningless transcripts.
_STT_CHECKPOINT = "facebook/wav2vec2-large-xlsr-53"
processor = Wav2Vec2Processor.from_pretrained(_STT_CHECKPOINT)
stt_model = TFWav2Vec2ForCTC.from_pretrained(_STT_CHECKPOINT)
def prepare_speech(speech, sampling_rate, target_sampling_rate=16000):
    """Convert decoded audio into the mono waveform the Wav2Vec2 processor expects.

    Args:
        speech: Samples as returned by ``soundfile.read``: 1-D for mono, or
            2-D ``(frames, channels)`` for multi-channel audio.
        sampling_rate: Sample rate of *speech* in Hz.
        target_sampling_rate: Rate passed to the processor (16 kHz by default).

    Returns:
        1-D float array at *target_sampling_rate* (empty if *speech* is empty).
    """
    import numpy as np

    speech = np.asarray(speech, dtype=np.float64)
    if speech.ndim > 1:
        # Down-mix channels: the processor expects a single 1-D waveform.
        speech = speech.mean(axis=1)
    if speech.size == 0 or sampling_rate == target_sampling_rate:
        return speech
    # Linear-interpolation resampling (no anti-aliasing filter). Adequate for
    # speech and avoids adding a resampling dependency.
    n_out = max(1, int(round(speech.size * target_sampling_rate / sampling_rate)))
    src_times = np.arange(speech.size) / sampling_rate
    dst_times = np.arange(n_out) / target_sampling_rate
    return np.interp(dst_times, src_times, speech)


def transcribe_audio(file_path, target_sampling_rate=16000):
    """Transcribe the speech in an audio file with the Wav2Vec2 CTC model.

    The file's actual sample rate (from ``soundfile.read``) is honoured. Audio
    is down-mixed to mono and resampled to *target_sampling_rate* before being
    passed to the processor, which is told that same rate.

    Args:
        file_path: Path to an audio file readable by ``soundfile``.
        target_sampling_rate: Rate the waveform is converted to (16 kHz).

    Returns:
        The decoded transcription, or ``""`` if the file has no samples.
    """
    speech, sampling_rate = sf.read(file_path)
    speech = prepare_speech(speech, sampling_rate, target_sampling_rate)
    if speech.size == 0:
        return ""
    input_values = processor(
        speech, return_tensors="tf", sampling_rate=target_sampling_rate
    ).input_values
    logits = stt_model(input_values).logits
    # Greedy CTC decoding: take the most likely token id per frame.
    predicted_ids = tf.argmax(logits, axis=-1)
    return processor.batch_decode(predicted_ids)[0]
# ------------------------------
# Language detection
# ------------------------------
def detect_language(text):
    """Detect the language of *text* and return a human-readable name.

    Returns:
        The language name from ``pycountry`` (e.g. "English") when the code
        can be resolved, otherwise the raw langdetect code (e.g. "zh-cn"), or
        ``"unknown"`` when langdetect cannot classify the text at all.
    """
    from langdetect.lang_detect_exception import LangDetectException

    try:
        lang_code = detect(text)
    except LangDetectException:  # e.g. empty text or text without letters
        return "unknown"
    # Convert the code to a full language name (for display and the PDF).
    # langdetect may return region-qualified codes such as "zh-cn"; look up
    # the base ISO 639-1 part. A failed lookup keeps the detected code rather
    # than discarding a successful detection.
    try:
        import pycountry

        lang = pycountry.languages.get(alpha_2=lang_code.split("-")[0])
    except (ImportError, LookupError):
        lang = None
    return lang.name if lang else lang_code
# ------------------------------
# TensorFlow translation (multilingual -> English)
# ------------------------------
# Multilingual -> English translator (Helsinki-NLP OPUS-MT checkpoint),
# loaded with the TensorFlow backend.
model_name = "Helsinki-NLP/opus-mt-mul-en"
translator = pipeline(task="translation", model=model_name, framework="tf")
# ------------------------------
# Gradio App
# ------------------------------
def transcribe_detect_translate(audio_file):
    """Transcribe *audio_file*, detect its language and translate it to English.

    Returns:
        ``(detected_language, original_text, english_text)`` matching the three
        output textboxes. User-facing messages go in the middle slot.
    """
    if not audio_file:
        return "", "Please record or upload an audio file.", ""
    # Step 1: Transcription
    original_text = transcribe_audio(audio_file)
    if not original_text.strip():
        return "", "No speech detected.", ""
    # Step 2: Language detection (a full name such as "English" when resolvable)
    detected_lang = detect_language(original_text)
    # Step 3: Translate unless already English. truncation=True keeps long
    # transcripts from exceeding the translation model's maximum input length.
    if detected_lang.lower() != "english":
        translated_text = translator(original_text, truncation=True)[0]["translation_text"]
    else:
        translated_text = original_text
    return detected_lang, original_text, translated_text


# Re-enter the Blocks app built above so the Voice Input tab is added to it.
# Creating a new `gr.Blocks() as demo` here would rebind `demo`, and the
# Image / Text / OCR tabs would never be launched.
with demo:
    with gr.Tab("Voice Input"):
        gr.Markdown("### 🎤 Multilingual Speech-to-Text + Translation + PDF")
        audio_input = gr.Audio(
            sources=["microphone", "upload"],
            type="filepath",
            label="🎙️ Record or upload an audio file",
        )
        transcribe_button = gr.Button("Transcribe, Detect & Translate")
        detected_lang_output = gr.Textbox(label="🌐 Detected Language", lines=1)
        transcription_output = gr.Textbox(label="🗣️ Original Transcription", lines=3)
        translation_output = gr.Textbox(label="🌍 Translated to English", lines=3)
        pdf_button_audio = gr.Button("Download Transcript (PDF)")

        transcribe_button.click(
            fn=transcribe_detect_translate,
            inputs=audio_input,
            outputs=[detected_lang_output, transcription_output, translation_output],
        )
        pdf_button_audio.click(
            fn=lambda lang, orig, trans: generate_pdf({
                "Detected Language": lang,
                "Original Transcription": orig,
                "Translated to English": trans,
            }),
            inputs=[detected_lang_output, transcription_output, translation_output],
            outputs=gr.File(label="Download PDF Report"),
        )

demo.launch()