# ------------------------------
# AI Multi‑Modal Assistant — Phase 2 (OCR Added)
# ------------------------------
import gradio as gr
from transformers import pipeline, Wav2Vec2Processor, TFWav2Vec2ForCTC
from PIL import Image
import torch
from torchvision import models, transforms
import pandas as pd
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
import io
from textwrap import wrap
import yake  # keyword extraction
import tempfile
from fpdf import FPDF
import time
import tensorflow as tf
import soundfile as sf
from langdetect import detect, DetectorFactory
DetectorFactory.seed = 0  # consistent language detection
import os
import subprocess
# Ensure tesseract is available at runtime
try:
    subprocess.run(["tesseract", "--version"], check=True, stdout=subprocess.PIPE)
except (subprocess.CalledProcessError, FileNotFoundError):
    print("Installing tesseract runtime...")
    os.system("apt-get update -y && apt-get install -y tesseract-ocr")
import pytesseract  # OCR
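# Note: on Hugging Face Spaces, listing "tesseract-ocr" in a packages.txt file
# installs it at build time; the apt-get fallback above only works when the
# process runs with root privileges.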
# ------------------------------
# 1. Load Models & Labels
# ------------------------------
# NLP pipelines
sentiment_model = pipeline(
    "sentiment-analysis",
    model="distilbert/distilbert-base-uncased-finetuned-sst-2-english",
)
# Summarization model
summarizer_model = pipeline("summarization", model="facebook/bart-large-cnn")
# Image classification model (weights= replaces the deprecated pretrained=True)
image_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
image_model.eval()
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# Load ImageNet class labels (ensure this file is in your folder)
with open("imagenet_classes.txt", "r") as f:
    imagenet_labels = [s.strip() for s in f.readlines()]
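# The label file is assumed to be the 1000-line list from the PyTorch hub
# examples, e.g. https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt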
# Keyword extraction
kw_extractor = yake.KeywordExtractor(lan="en", top=5)
# ------------------------------
# 2. Helper Functions
# ------------------------------
def analyze_text(text: str) -> dict:
    sentiment = sentiment_model(text)[0]
    summary = summarizer_model(
        text, max_length=min(len(text.split()) + 10, 50), min_length=5
    )[0]["summary_text"]
    keywords = [kw for kw, score in kw_extractor.extract_keywords(text)]
    return {
        "Sentiment": sentiment["label"],
        "Sentiment Score": round(sentiment["score"], 3),
        "Summary": summary,
        "Keywords": keywords,
    }
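# Example (illustrative): analyze_text("The new interface is fast and intuitive.")
# returns something like {"Sentiment": "POSITIVE", "Sentiment Score": 0.999,
# "Summary": "...", "Keywords": ["new interface", ...]}.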
def analyze_image(image: Image.Image) -> dict:
    img_t = preprocess(image).unsqueeze(0)
    with torch.no_grad():
        outputs = image_model(img_t)
    class_idx = outputs.argmax().item()
    if 0 <= class_idx < len(imagenet_labels):
        class_label = imagenet_labels[class_idx]
    else:
        class_label = f"Class index {class_idx}"
    return {"Predicted Class Index": class_idx, "Predicted Class Label": class_label}
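# Inside analyze_image, a confidence score could also be reported (sketch):
#   probs = torch.nn.functional.softmax(outputs, dim=1)
#   confidence = round(probs[0, class_idx].item(), 3)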
def ocr_image(image: Image.Image) -> dict:
    """Extract text from an uploaded image using Tesseract OCR."""
    text = pytesseract.image_to_string(image)
    return {"Extracted Text": text}
# Basic ReportLab PDF generator. The Unicode-capable FPDF version defined
# further below rebinds the name generate_pdf, so that version is the one the
# UI handlers call at runtime; this one is kept under a distinct name.
def generate_pdf_reportlab(results: dict) -> str:
    buffer = io.BytesIO()
    c = canvas.Canvas(buffer, pagesize=letter)
    width, height = letter
    c.setFont("Helvetica", 12)
    c.drawString(50, height - 50, "AI Multi-Modal Assistant Report")
    y = height - 80  # starting y position
    for key, value in results.items():
        # Convert any complex types to string
        if isinstance(value, (list, dict)):
            value = str(value)
        # Wrap long lines (max 90 chars per line)
        wrapped_lines = wrap(f"{key}: {value}", width=90)
        for line in wrapped_lines:
            c.drawString(50, y, line)
            y -= 18
            if y < 60:  # new page if needed
                c.showPage()
                c.setFont("Helvetica", 12)
                y = height - 50
        y -= 10  # space between entries
    c.save()
    buffer.seek(0)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as tmp:
        tmp.write(buffer.getvalue())
        tmp_path = tmp.name
    return tmp_path
# ------------------------------
# 3. Multi‑Modal Analysis Function
# ------------------------------
def analyze(input_data):
    # Handles both text and image inputs from Gradio
    if isinstance(input_data, str) and input_data.strip():
        return analyze_text(input_data)
    elif isinstance(input_data, dict) and "image" in input_data:
        return analyze_image(input_data["image"])
    elif isinstance(input_data, Image.Image):
        return analyze_image(input_data)
    else:
        return {"Error": "Please enter text or upload an image."}
# ------------------------------
# 4. Gradio UI Layout
# ------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## AI Multi‑Modal Assistant")
    # ------------------ Image Analysis Tab ------------------
    with gr.Tab("Image Analysis"):
        image_input = gr.Image(type="pil", label="Upload an image for classification")
        analyze_image_button = gr.Button("Analyze Image")
        image_output = gr.JSON(label="Image Analysis Results")
        pdf_button_image = gr.Button("Download Report (PDF)")
        analyze_image_button.click(fn=analyze, inputs=image_input, outputs=image_output)
        pdf_button_image.click(
            fn=lambda x: generate_pdf(analyze(x)),
            inputs=image_input,
            outputs=gr.File(label="Download PDF Report"),
        )
    # ------------------ Text Analysis Tab ------------------
    with gr.Tab("Text Analysis"):
        text_input = gr.Textbox(
            label="Enter text to analyze",
            placeholder="Type your text here...",
            lines=5,
        )
        analyze_text_button = gr.Button("Analyze Text")
        text_output = gr.JSON(label="Text Analysis Results")
        pdf_button_text = gr.Button("Download Report (PDF)")
        # Text analysis events
        analyze_text_button.click(fn=analyze, inputs=text_input, outputs=text_output)
        pdf_button_text.click(
            fn=lambda x: generate_pdf(analyze(x)),
            inputs=text_input,
            outputs=gr.File(label="Download PDF Report"),
        )
    # ------------------ OCR Tab ------------------
    with gr.Tab("OCR"):
        ocr_input = gr.Image(type="pil", label="Upload image for OCR")
        ocr_button = gr.Button("Extract Text (OCR)")
        ocr_output = gr.JSON(label="OCR Results")
        pdf_button_ocr = gr.Button("Download OCR PDF")
        ocr_button.click(fn=ocr_image, inputs=ocr_input, outputs=ocr_output)
        pdf_button_ocr.click(
            fn=generate_pdf,
            inputs=ocr_output,
            outputs=gr.File(label="Download PDF Report"),
        )
# ------------------------------
# Enhanced Tab: Voice Input (Speech-to-Text + Translation + Language Detection)
# ------------------------------
# Auto-download DejaVuSans.ttf so the PDF can render all Unicode scripts
# ------------------------------
FONT_PATH = "DejaVuSans.ttf"
if not os.path.exists(FONT_PATH):
    import requests
    url = "https://github.com/dejavu-fonts/dejavu-fonts/raw/master/ttf/DejaVuSans.ttf"
    r = requests.get(url, timeout=60)
    with open(FONT_PATH, "wb") as f:
        f.write(r.content)
# ------------------------------
# PDF generator with full Unicode support (rebinds generate_pdf for the UI)
# ------------------------------
def generate_pdf(content_dict, output_file=None):
    pdf = FPDF()
    pdf.add_page()
    pdf.add_font("DejaVu", "", FONT_PATH, uni=True)
    # Register the same TTF for the bold style so set_font("DejaVu", "B") works
    pdf.add_font("DejaVu", "B", FONT_PATH, uni=True)
    pdf.set_font("DejaVu", size=12)
    for key, value in content_dict.items():
        pdf.set_font("DejaVu", "B", 12)
        pdf.cell(0, 10, f"{key}:", ln=True)
        pdf.set_font("DejaVu", "", 12)
        pdf.multi_cell(0, 10, str(value))  # values may be lists or numbers
        pdf.ln(5)
    if not output_file:
        output_file = f"transcript_{int(time.time())}.pdf"
    pdf.output(output_file)
    return output_file
# ------------------------------
# TF-based STT (Wav2Vec2)
# ------------------------------
# NOTE: the bare "facebook/wav2vec2-large-xlsr-53" checkpoint ships without a
# tokenizer or CTC head, so it cannot transcribe on its own. A CTC fine-tuned
# checkpoint is required; "facebook/wav2vec2-base-960h" (English) is used here,
# and any fine-tuned XLSR variant can be substituted for other languages.
STT_CHECKPOINT = "facebook/wav2vec2-base-960h"
processor = Wav2Vec2Processor.from_pretrained(STT_CHECKPOINT)
stt_model = TFWav2Vec2ForCTC.from_pretrained(STT_CHECKPOINT)
def transcribe_audio(file_path):
    speech, sample_rate = sf.read(file_path)
    if speech.ndim > 1:  # downmix stereo to mono
        speech = speech.mean(axis=1)
    if sample_rate != 16000:  # Wav2Vec2 expects 16 kHz audio
        import librosa  # assumes librosa is available for resampling
        speech = librosa.resample(speech, orig_sr=sample_rate, target_sr=16000)
    input_values = processor(speech, return_tensors="tf", sampling_rate=16000).input_values
    logits = stt_model(input_values).logits
    predicted_ids = tf.argmax(logits, axis=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    return transcription
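# Example (illustrative): transcribe_audio("sample.wav") returns the decoded
# string; CTC output from this checkpoint is uppercase and unpunctuated,
# e.g. "HELLO WORLD".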
# ------------------------------
# Language detection
# ------------------------------
def detect_language(text):
    try:
        lang_code = detect(text)
        # Convert the ISO 639-1 code to a full language name (for the PDF)
        import pycountry
        lang = pycountry.languages.get(alpha_2=lang_code)
        return lang.name if lang else lang_code
    except Exception:
        return "unknown"
# ------------------------------
# TensorFlow translation (multilingual -> English)
# ------------------------------
model_name = "Helsinki-NLP/opus-mt-mul-en"
translator = pipeline("translation", model=model_name, framework="tf")
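# Example (illustrative): translator("Hola, ¿cómo estás?")[0]["translation_text"]
# should yield an English rendering along the lines of "Hello, how are you?".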
# ------------------------------
# Gradio App: Voice tab. Defined as its own Blocks under a distinct name and
# combined with the main demo at launch, so it does not shadow `demo` above.
# ------------------------------
with gr.Blocks() as voice_demo:
    with gr.Tab("Voice Input"):
        gr.Markdown("### 🎤 Multilingual Speech-to-Text + Translation + PDF")
        audio_input = gr.Audio(
            sources=["microphone", "upload"],
            type="filepath",
            label="🎙️ Record or upload an audio file",
        )
        transcribe_button = gr.Button("Transcribe, Detect & Translate")
        detected_lang_output = gr.Textbox(label="🌐 Detected Language", lines=1)
        transcription_output = gr.Textbox(label="🗣️ Original Transcription", lines=3)
        translation_output = gr.Textbox(label="🌍 Translated to English", lines=3)
        pdf_button_audio = gr.Button("Download Transcript (PDF)")
        def transcribe_detect_translate(audio_file):
            if not audio_file:
                return "", "Please record or upload an audio file.", ""
            # Step 1: Transcription
            original_text = transcribe_audio(audio_file)
            if not original_text.strip():
                return "", "No speech detected.", ""
            # Step 2: Language detection
            detected_lang = detect_language(original_text)
            # Step 3: Translate if not English
            translated_text = (
                translator(original_text)[0]["translation_text"]
                if detected_lang.lower() != "english"
                else original_text
            )
            return detected_lang, original_text, translated_text
        transcribe_button.click(
            fn=transcribe_detect_translate,
            inputs=audio_input,
            outputs=[detected_lang_output, transcription_output, translation_output],
        )
        pdf_button_audio.click(
            fn=lambda lang, orig, trans: generate_pdf({
                "Detected Language": lang,
                "Original Transcription": orig,
                "Translated to English": trans,
            }),
            inputs=[detected_lang_output, transcription_output, translation_output],
            outputs=gr.File(label="Download PDF Report"),
        )
# Combine the main assistant and the voice tool into one tabbed app
app = gr.TabbedInterface([demo, voice_demo], ["Assistant", "Voice"])
app.launch()