# app.py
import gradio as gr
import torch
import logging
from transformers import pipeline
from models import load_blip, load_clip, load_omniphi
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filename='inference.log')
logger = logging.getLogger(__name__)
# Configuration
DEVICE = torch.device("cpu") # Use CPU to avoid OOM on HF Spaces
TORCH_DTYPE = torch.float32
BLIP_MODEL = "Salesforce/blip-image-captioning-base"
CLIP_MODEL = "openai/clip-vit-base-patch32"
PHI_MODEL = "microsoft/phi-3-mini-4k-instruct"
CHECKPOINT_DIR = "./checkpoints/Epoch_peft_1"
WHISPER_MODEL = "openai/whisper-small"
COLOR_THEME = "blue" # Can be "blue" or "orange"
# Initialize transcriber
def initialize_transcriber(model_name):
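    """Create a Whisper ASR pipeline on CPU; on failure, return a stub that reports the error."""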
try:
transcriber = pipeline("automatic-speech-recognition", model=model_name, device=-1) # CPU
logger.info("Whisper transcriber initialized successfully")
return transcriber
except Exception as e:
logger.error(f"Failed to initialize Whisper transcriber: {e}")
return lambda audio: {"text": "Error: Transcriber not available"}
# Inference function
def generate_description(image, text_prompt, audio, model_choice, transcriber, blip_model, blip_processor, clip_model, clip_processor, omniphi_model, omniphi_tokenizer, device):
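    """Generate an image description with BLIP or the fine-tuned OmniPhi model.

    Exactly one of `text_prompt` / `audio` must be provided. `image` may be omitted
    only for text-only OmniPhi queries. Returns the description, or an error string
    for display in the UI.
    """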
    has_text = bool(text_prompt and text_prompt.strip())
    text_only = image is None and (has_text or audio is not None)
    if image is None and not text_only:
        logger.error("No image provided")
        return "Error: Please upload an image."
    if not has_text and audio is None:
        logger.error("No text prompt or audio provided")
        return "Error: Please provide either a text prompt or record your voice."
    if has_text and audio is not None:
        logger.error("Both text prompt and audio provided")
        return "Error: Please provide only one of text prompt or voice recording."
# Process prompt
    if has_text:
        prompt = text_prompt.strip().rstrip("?")
else:
try:
transcription = transcriber(audio)["text"]
prompt = transcription.strip().rstrip("?")
except Exception as e:
logger.error(f"Transcription failed: {e}")
return f"Error: Failed to transcribe audio: {e}"
detailed_prompt = f"Provide a detailed description of the scene in the image, including objects, colors, actions, and context: {prompt}"
logger.info(f"Processed prompt: {detailed_prompt}")
# BLIP inference
if model_choice == "BLIP":
if text_only:
return "Error: BLIP requires an image."
try:
image = image.convert("RGB").resize((224, 224))
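            # BLIP conditional captioning is seeded here with the stub "A photo of";
            # the user's prompt gates validation above but is not passed to BLIP.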
inputs = blip_processor(images=image, text="A photo of", return_tensors="pt").to(device)
            outputs = blip_model.generate(
                **inputs,
                max_length=200,
                num_beams=5,  # beam search; temperature is ignored without do_sample=True, so it was dropped
                no_repeat_ngram_size=3
            )
output = blip_processor.decode(outputs[0], skip_special_tokens=True)
logger.info(f"BLIP output: {output}")
return output
except Exception as e:
logger.error(f"BLIP generation failed: {e}")
return f"Error: BLIP failed to generate description: {e}. Try OmniPhi model."
# OmniPhi inference
else:
if omniphi_model is None or omniphi_tokenizer is None:
logger.error("OmniPhi model loading failed")
return "Error: OmniPhi model could not be loaded. Due to resource constraints, Try BLIP."
omniphi_model.eval()
prompt = "[IMG] " + detailed_prompt if not text_only else detailed_prompt
try:
tokenized = omniphi_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)
text_input_ids = tokenized["input_ids"].to(device)
attention_mask = tokenized["attention_mask"].to(device)
logger.info(f"Natural input_ids: {text_input_ids}")
if text_only:
image_embedding = None
else:
image = image.convert("RGB").resize((224, 224))
image_inputs = clip_processor(images=image, return_tensors="pt").to(device)
image_embedding = clip_model.get_image_features(**image_inputs)
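            # OmniPhiModel is expected to project the CLIP image features into Phi's
            # embedding space and fuse them with the token embeddings (see models.py).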
fused_embeddings = omniphi_model(
text_input_ids=text_input_ids,
attention_mask=attention_mask,
image_embedding=image_embedding
)
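            # NOTE: attention_mask must match the sequence length of fused_embeddings;
            # if the fusion step prepends image tokens, the mask would need extending to match.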
generated_ids = omniphi_model.phi.generate(
inputs_embeds=fused_embeddings,
attention_mask=attention_mask,
max_new_tokens=100,
do_sample=False,
repetition_penalty=1.2
)
output = omniphi_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
logger.info(f"OmniPhi output: {output}")
return output
except Exception as e:
logger.error(f"OmniPhi generation failed: {e}")
return f"Error: OmniPhi failed to generate description: {e}"
def create_gradio_interface(generate_fn, color_theme="blue"):
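    """Build the themed Gradio Blocks UI and wire the submit button to `generate_fn`."""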
# Set color variables based on theme
if color_theme == "blue":
primary_color = "#1e40af"
icon_color = "#2563eb"
icon = "🔮"
app_name = "OmniPhi"
gradio_theme = gr.themes.Soft(primary_hue="blue", secondary_hue="gray")
else:
primary_color = "#9a3412"
icon_color = "#ea580c"
icon = "🔶"
app_name = "OmniPhi Orange"
gradio_theme = gr.themes.Soft(primary_hue="orange", secondary_hue="gray")
# Create Blocks interface with Gradio theme
with gr.Blocks(theme=gradio_theme) as iface:
# Header section (simplified, no gradients or shadows)
gr.Markdown(f"""
<div style=" background-color:#9fc5e8; text-align:center; padding:20px; border-radius:12px; margin-bottom:20px; ">
<h1 style="color:{primary_color}; font-size:2.2em; font-weight:700; margin-bottom:10px;">{icon} {app_name}</h1>
<p style="color:#334155; font-size:1.1em; line-height:1.5;">Advanced Multi-Modal AI with BLIP or finetuned LLM Integration</p>
</div>
""")
# Main content - parallel layout with Rows and Columns
with gr.Row(equal_height=True):
# Left column for image upload
with gr.Column():
gr.Markdown(f"""
<div style=" background-color:#9fc5e8; margin-bottom:10px; padding:15px; border-radius:10px;">
<span style="font-size:1.5em; margin-right:8px; color:{icon_color};">🖼️</span>
<span style="color:{primary_color}; font-weight:600; font-size:1.05em;">Upload Image</span>
</div>
""")
                image_input = gr.Image(
                    type="pil",
                    show_label=False,
                    height=300,  # constrain the preview size
                    width=300
                )
# Model selection below image
gr.Markdown(f"""
<div style="margin-top:15px; margin-bottom:10px; padding:15px; border-radius:10px; background-color:#9fc5e8;">
<span style="font-size:1.5em; margin-right:8px; color:{icon_color};">⚙️</span>
<span style="color:{primary_color}; font-weight:600; font-size:1.05em;">Model Selection</span>
</div>
""")
model_choice = gr.Radio(
choices=["BLIP", "OmniPhi"],
value="BLIP",
label="Model Choice",
interactive=True
)
gr.Markdown("</div>")
# Middle column for text and voice inputs
with gr.Column():
gr.Markdown(f"""
<div style="margin-bottom:10px; padding:15px; border-radius:10px; background-color:#9fc5e8;">
<span style="font-size:1.5em; margin-right:8px; color:{icon_color};">💬</span>
<span style="color:{primary_color}; font-weight:600; font-size:1.05em;">Text Instruction</span>
</div>
""")
                text_input = gr.Textbox(
                    show_label=False,
                    placeholder="e.g., Describe this image in detail, focusing on the environment...",
                    lines=3
                )
gr.Markdown(f"""
<div style="margin-top:15px; margin-bottom:10px; padding:15px; border-radius:10px; background-color:#9fc5e8;">
<span style="font-size:1.5em; margin-right:8px; color:{icon_color};">🎙️</span>
<span style="color:{primary_color}; font-weight:600; font-size:1.05em;">Voice Instruction (optional)</span>
</div>
""")
                audio_input = gr.Audio(
                    sources=["microphone"],
                    type="filepath",  # hand the ASR pipeline a file path it can read directly
                    show_label=False
                )
# Generate button with Gradio styling
submit_btn = gr.Button(
"Generate Description",
variant="primary",
size="lg"
)
# Right column for output
with gr.Column():
gr.Markdown(f"""
<div style="margin-bottom:10px; padding:15px; border-radius:10px; background-color:#9fc5e8;">
<span style="font-size:1.5em; margin-right:8px; color:{icon_color};">✨</span>
<span style="color:{primary_color}; font-weight:600; font-size:1.05em;">Generated Description</span>
</div>
""")
                output = gr.Textbox(
                    show_label=False,
                    lines=12,
                    placeholder="Your description will appear here after generation..."
                )
# Footer
gr.Markdown(f"""
<div style="text-align:center; margin-top:30px; color:#64748b; font-size:0.9em; ">
Powered by OmniPhi. Upload your image and provide instructions through text or voice.
</div>
""")
# Connect the button to the function
submit_btn.click(
fn=generate_fn,
inputs=[image_input, text_input, audio_input, model_choice],
outputs=output
)
return iface
# Main execution
if __name__ == "__main__":
# Load models
transcriber = initialize_transcriber(WHISPER_MODEL)
blip_model, blip_processor = load_blip(BLIP_MODEL, DEVICE, TORCH_DTYPE)
clip_model, clip_processor = load_clip(CLIP_MODEL, DEVICE, TORCH_DTYPE)
omniphi_model, omniphi_tokenizer = load_omniphi(CHECKPOINT_DIR, PHI_MODEL, CLIP_MODEL, DEVICE)
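    # load_omniphi is assumed to return (None, None) if the checkpoint fails to load;
    # generate_description guards against that and suggests BLIP instead.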
    # Bind the loaded models into the UI callback
    def generate_fn(image, text_prompt, audio, model_choice):
        return generate_description(
            image, text_prompt, audio, model_choice, transcriber, blip_model, blip_processor,
            clip_model, clip_processor, omniphi_model, omniphi_tokenizer, DEVICE
        )
# Launch Gradio interface
iface = create_gradio_interface(generate_fn, color_theme=COLOR_THEME)
iface.launch(server_name="0.0.0.0", server_port=7860)