import gradio as gr
import torch
import logging
from PIL import Image
from transformers import pipeline
from models import load_blip, load_clip, load_omniphi


logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', filename='inference.log')
logger = logging.getLogger(__name__)


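# Inference configuration: model identifiers and CPU-only settings.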
DEVICE = torch.device("cpu")
TORCH_DTYPE = torch.float32
BLIP_MODEL = "Salesforce/blip-image-captioning-base"
CLIP_MODEL = "openai/clip-vit-base-patch32"
PHI_MODEL = "microsoft/phi-3-mini-4k-instruct"
CHECKPOINT_DIR = "./checkpoints/Epoch_peft_1"
WHISPER_MODEL = "openai/whisper-small"
COLOR_THEME = "blue"


def initialize_transcriber(model_name):
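    """Load a Whisper ASR pipeline on CPU, falling back to a stub on failure."""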
    try:
        transcriber = pipeline("automatic-speech-recognition", model=model_name, device=-1)
        logger.info("Whisper transcriber initialized successfully")
        return transcriber
    except Exception as e:
        logger.error(f"Failed to initialize Whisper transcriber: {e}")
        return lambda audio: {"text": "Error: Transcriber not available"}


def generate_description(
    image, text_prompt, audio, model_choice, transcriber,
    blip_model, blip_processor, clip_model, clip_processor,
    omniphi_model, omniphi_tokenizer, device,
):
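    """Generate a scene description with BLIP or the fine-tuned OmniPhi model.

    Exactly one of `text_prompt` or `audio` must be provided; audio is
    transcribed with Whisper before prompting.
    """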
    has_text = bool(text_prompt and text_prompt.strip())
    text_only = image is None and (has_text or audio is not None)
    if image is None and not text_only:
        logger.error("No image provided")
        return "Error: Please upload an image."
    if not has_text and audio is None:
        logger.error("No text prompt or audio provided")
        return "Error: Please provide either a text prompt or record your voice."
    if has_text and audio is not None:
        logger.error("Both text prompt and audio provided")
        return "Error: Please provide only one of text prompt or voice recording."
    if has_text:
        prompt = text_prompt.strip().rstrip("?")
    else:
        try:
            transcription = transcriber(audio)["text"]
            prompt = transcription.strip().rstrip("?")
        except Exception as e:
            logger.error(f"Transcription failed: {e}")
            return f"Error: Failed to transcribe audio: {e}"

    detailed_prompt = f"Provide a detailed description of the scene in the image, including objects, colors, actions, and context: {prompt}"
    logger.info(f"Processed prompt: {detailed_prompt}")
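    # BLIP path: conditional captioning from the fixed prefix "A photo of";
    # the user prompt is not consumed by BLIP here.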
    if model_choice == "BLIP":
        if text_only:
            return "Error: BLIP requires an image."
        try:
            image = image.convert("RGB").resize((224, 224))
            inputs = blip_processor(images=image, text="A photo of", return_tensors="pt").to(device)
            # temperature is omitted: beam search ignores it unless do_sample=True.
            outputs = blip_model.generate(
                **inputs,
                max_length=200,
                num_beams=5,
                no_repeat_ngram_size=3
            )
            output = blip_processor.decode(outputs[0], skip_special_tokens=True)
            logger.info(f"BLIP output: {output}")
            return output
        except Exception as e:
            logger.error(f"BLIP generation failed: {e}")
            return f"Error: BLIP failed to generate description: {e}. Try the OmniPhi model."
    else:
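        # OmniPhi path: CLIP image features are fused into the Phi model's
        # embedding space before generation.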
        if omniphi_model is None or omniphi_tokenizer is None:
            logger.error("OmniPhi model loading failed")
            return "Error: OmniPhi model could not be loaded due to resource constraints. Try BLIP instead."
        omniphi_model.eval()
        prompt = ("[IMG] " + detailed_prompt) if not text_only else detailed_prompt
        try:
            tokenized = omniphi_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128)
            text_input_ids = tokenized["input_ids"].to(device)
            attention_mask = tokenized["attention_mask"].to(device)
            logger.info(f"Tokenized input_ids: {text_input_ids}")

            if text_only:
                image_embedding = None
            else:
                image = image.convert("RGB").resize((224, 224))
                image_inputs = clip_processor(images=image, return_tensors="pt").to(device)
                image_embedding = clip_model.get_image_features(**image_inputs)

            fused_embeddings = omniphi_model(
                text_input_ids=text_input_ids,
                attention_mask=attention_mask,
                image_embedding=image_embedding
            )
            # The fused sequence can be longer than the text-only mask once image
            # tokens are added, and no padding is present for a single prompt, so
            # rebuild the mask to match the fused embeddings.
            fused_attention_mask = torch.ones(fused_embeddings.shape[:2], dtype=torch.long, device=device)
            generated_ids = omniphi_model.phi.generate(
                inputs_embeds=fused_embeddings,
                attention_mask=fused_attention_mask,
                max_new_tokens=100,
                do_sample=False,
                repetition_penalty=1.2
            )
            output = omniphi_tokenizer.decode(generated_ids[0], skip_special_tokens=True)
            logger.info(f"OmniPhi output: {output}")
            return output
        except Exception as e:
            logger.error(f"OmniPhi generation failed: {e}")
            return f"Error: OmniPhi failed to generate description: {e}"


def create_gradio_interface(generate_fn, color_theme="blue"):
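    """Build the Gradio Blocks UI and wire the submit button to `generate_fn`."""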
    if color_theme == "blue":
        primary_color = "#1e40af"
        icon_color = "#2563eb"
        icon = "🔮"
        app_name = "OmniPhi"
        gradio_theme = gr.themes.Soft(primary_hue="blue", secondary_hue="gray")
    else:
        primary_color = "#9a3412"
        icon_color = "#ea580c"
        icon = "🔶"
        app_name = "OmniPhi Orange"
        gradio_theme = gr.themes.Soft(primary_hue="orange", secondary_hue="gray")

    with gr.Blocks(theme=gradio_theme) as iface:
        gr.Markdown(f"""
        <div style="background-color:#9fc5e8; text-align:center; padding:20px; border-radius:12px; margin-bottom:20px;">
            <h1 style="color:{primary_color}; font-size:2.2em; font-weight:700; margin-bottom:10px;">{icon} {app_name}</h1>
            <p style="color:#334155; font-size:1.1em; line-height:1.5;">Advanced multi-modal AI with BLIP or fine-tuned LLM integration</p>
        </div>
        """)
        with gr.Row(equal_height=True):
            with gr.Column():
                gr.Markdown(f"""
                <div style="background-color:#9fc5e8; margin-bottom:10px; padding:15px; border-radius:10px;">
                    <span style="font-size:1.5em; margin-right:8px; color:{icon_color};">🖼️</span>
                    <span style="color:{primary_color}; font-weight:600; font-size:1.05em;">Upload Image</span>
                </div>
                """)
                image_input = gr.Image(
                    type="pil",
                    show_label=False,
                    height=300,
                    width=300
                )

                gr.Markdown(f"""
                <div style="margin-top:15px; margin-bottom:10px; padding:15px; border-radius:10px; background-color:#9fc5e8;">
                    <span style="font-size:1.5em; margin-right:8px; color:{icon_color};">⚙️</span>
                    <span style="color:{primary_color}; font-weight:600; font-size:1.05em;">Model Selection</span>
                </div>
                """)
                model_choice = gr.Radio(
                    choices=["BLIP", "OmniPhi"],
                    value="BLIP",
                    label="Model Choice",
                    interactive=True
                )

            with gr.Column():
                gr.Markdown(f"""
                <div style="margin-bottom:10px; padding:15px; border-radius:10px; background-color:#9fc5e8;">
                    <span style="font-size:1.5em; margin-right:8px; color:{icon_color};">💬</span>
                    <span style="color:{primary_color}; font-weight:600; font-size:1.05em;">Text Instruction</span>
                </div>
                """)
                text_input = gr.Textbox(
                    show_label=False,
                    placeholder="e.g., Describe this image in detail, focusing on the environment...",
                    lines=3
                )

                gr.Markdown(f"""
                <div style="margin-top:15px; margin-bottom:10px; padding:15px; border-radius:10px; background-color:#9fc5e8;">
                    <span style="font-size:1.5em; margin-right:8px; color:{icon_color};">🎙️</span>
                    <span style="color:{primary_color}; font-weight:600; font-size:1.05em;">Voice Instruction (optional)</span>
                </div>
                """)
                # filepath output lets the Whisper pipeline read the recording directly.
                audio_input = gr.Audio(
                    sources=["microphone"],
                    type="filepath",
                    show_label=False
                )

                submit_btn = gr.Button(
                    "Generate Description",
                    variant="primary",
                    size="lg"
                )

            with gr.Column():
                gr.Markdown(f"""
                <div style="margin-bottom:10px; padding:15px; border-radius:10px; background-color:#9fc5e8;">
                    <span style="font-size:1.5em; margin-right:8px; color:{icon_color};">✨</span>
                    <span style="color:{primary_color}; font-weight:600; font-size:1.05em;">Generated Description</span>
                </div>
                """)
                output = gr.Textbox(
                    show_label=False,
                    lines=12,
                    placeholder="Your description will appear here after generation..."
                )

        gr.Markdown("""
        <div style="text-align:center; margin-top:30px; color:#64748b; font-size:0.9em;">
            Powered by OmniPhi. Upload your image and provide instructions through text or voice.
        </div>
        """)

        submit_btn.click(
            fn=generate_fn,
            inputs=[image_input, text_input, audio_input, model_choice],
            outputs=output
        )

    return iface


if __name__ == "__main__":
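    # Load every model once at startup so all requests share the same instances.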
    transcriber = initialize_transcriber(WHISPER_MODEL)
    blip_model, blip_processor = load_blip(BLIP_MODEL, DEVICE, TORCH_DTYPE)
    clip_model, clip_processor = load_clip(CLIP_MODEL, DEVICE, TORCH_DTYPE)
    omniphi_model, omniphi_tokenizer = load_omniphi(CHECKPOINT_DIR, PHI_MODEL, CLIP_MODEL, DEVICE)

    # Bind the loaded models so the Gradio callback only receives UI inputs.
    generate_fn = lambda image, text_prompt, audio, model_choice: generate_description(
        image, text_prompt, audio, model_choice, transcriber, blip_model, blip_processor,
        clip_model, clip_processor, omniphi_model, omniphi_tokenizer, DEVICE
    )

    iface = create_gradio_interface(generate_fn, color_theme=COLOR_THEME)
    iface.launch(server_name="0.0.0.0", server_port=7860)