# ai-toolkit / app.py — AI multi-model toolkit
# (source: darelwork356-oss, commit 0ed8a43)
import gradio as gr
import torch
from transformers import (
TrOCRProcessor, VisionEncoderDecoderModel,
WhisperProcessor, WhisperForConditionalGeneration,
MarianMTModel, MarianTokenizer,
pipeline
)
from PIL import Image
import torchaudio
import numpy as np
print("Loading models...")
# OCR
ocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
ocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
# Speech to Text
whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
# Translation
trans_en_es_tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-es")
trans_en_es_model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-es")
trans_es_en_tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-es-en")
trans_es_en_model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-es-en")
# Sentiment
sentiment = pipeline("sentiment-analysis")
# Summarization
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
print("Models loaded successfully")
def ocr_extract(image):
if image is None:
return "Error: No image"
pixel_values = ocr_processor(image, return_tensors="pt").pixel_values
generated_ids = ocr_model.generate(pixel_values)
text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
return text
def speech_to_text(audio):
if audio is None:
return "Error: No audio"
audio_array, sr = torchaudio.load(audio)
if sr != 16000:
resampler = torchaudio.transforms.Resample(sr, 16000)
audio_array = resampler(audio_array)
inputs = whisper_processor(audio_array.squeeze().numpy(), sampling_rate=16000, return_tensors="pt")
generated_ids = whisper_model.generate(inputs.input_features)
text = whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
return text
def translate_en_to_es(text):
if not text:
return "Error: No text"
inputs = trans_en_es_tokenizer(text, return_tensors="pt", padding=True)
translated = trans_en_es_model.generate(**inputs)
return trans_en_es_tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
def translate_es_to_en(text):
if not text:
return "Error: No text"
inputs = trans_es_en_tokenizer(text, return_tensors="pt", padding=True)
translated = trans_es_en_model.generate(**inputs)
return trans_es_en_tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
def analyze_sentiment(text):
if not text:
return "Error: No text"
result = sentiment(text)[0]
return f"{result['label']}: {result['score']:.2%}"
def summarize(text):
if not text or len(text) < 100:
return "Text too short"
result = summarizer(text, max_length=130, min_length=30, do_sample=False)
return result[0]['summary_text']
with gr.Blocks(title="AI Toolkit") as app:
gr.Markdown("# AI Multi-Model Toolkit")
with gr.Tab("OCR"):
with gr.Row():
ocr_img = gr.Image(type="pil", label="Image")
ocr_out = gr.Textbox(label="Text", lines=10)
gr.Button("Extract").click(ocr_extract, ocr_img, ocr_out)
with gr.Tab("Speech to Text"):
with gr.Row():
audio_in = gr.Audio(type="filepath", label="Audio")
audio_out = gr.Textbox(label="Transcription", lines=10)
gr.Button("Transcribe").click(speech_to_text, audio_in, audio_out)
with gr.Tab("Translation"):
with gr.Row():
with gr.Column():
trans_in = gr.Textbox(label="Input", lines=5)
with gr.Row():
btn_en_es = gr.Button("EN to ES")
btn_es_en = gr.Button("ES to EN")
trans_out = gr.Textbox(label="Output", lines=5)
btn_en_es.click(translate_en_to_es, trans_in, trans_out)
btn_es_en.click(translate_es_to_en, trans_in, trans_out)
with gr.Tab("Sentiment"):
with gr.Row():
sent_in = gr.Textbox(label="Text", lines=5)
sent_out = gr.Textbox(label="Result", lines=2)
gr.Button("Analyze").click(analyze_sentiment, sent_in, sent_out)
with gr.Tab("Summarize"):
with gr.Row():
summ_in = gr.Textbox(label="Long Text", lines=10)
summ_out = gr.Textbox(label="Summary", lines=5)
gr.Button("Summarize").click(summarize, summ_in, summ_out)
app.launch()