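"""Gradio app for a LLaVA-Phi vision-language chat demo, intended to run on a Hugging Face ZeroGPU Space."""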
import spaces  # ZeroGPU: import before torch so the GPU decorator can manage CUDA initialization
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, CLIPProcessor, CLIPModel
from PIL import Image
import logging
import numpy

# Setup logging
logging.basicConfig(level=logging.INFO)
class LLaVAPhiModel:
    def __init__(self, model_id="sagar007/Lava_phi"):
        self.device = "cuda"
        self.model_id = model_id
        logging.info("Initializing LLaVA-Phi model...")

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        try:
            # Use CLIPProcessor directly instead of AutoProcessor
            self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
            logging.info("Successfully loaded CLIP processor")
        except Exception as e:
            logging.error(f"Failed to load CLIP processor: {str(e)}")
            self.processor = None

        self.history = []
        self.model = None
        self.clip = None
    def ensure_models_loaded(self):
        """Ensure models are loaded in GPU context"""
        if self.model is None:
            # Load main model with updated quantization config
            from transformers import BitsAndBytesConfig
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4"
            )

            try:
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_id,
                    quantization_config=quantization_config,
                    device_map="auto",
                    torch_dtype=torch.float16,  # keep non-quantized modules in the same dtype as bnb_4bit_compute_dtype
                    trust_remote_code=True
                )
                self.model.config.pad_token_id = self.tokenizer.eos_token_id
                logging.info("Successfully loaded main model")
            except Exception as e:
                logging.error(f"Failed to load main model: {str(e)}")
                raise

        if self.clip is None:
            try:
                # Use CLIPModel directly instead of AutoModel
                self.clip = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
                logging.info("Successfully loaded CLIP model")
            except Exception as e:
                logging.error(f"Failed to load CLIP model: {str(e)}")
                self.clip = None
    def process_image(self, image):
        """Process image through CLIP if available"""
        try:
            self.ensure_models_loaded()

            if self.clip is None or self.processor is None:
                logging.warning("CLIP model or processor not available")
                return None

            # Convert image to correct format
            if isinstance(image, str):
                image = Image.open(image)
            elif isinstance(image, numpy.ndarray):
                image = Image.fromarray(image)

            # Ensure image is in RGB mode
            if image.mode != 'RGB':
                image = image.convert('RGB')

            with torch.no_grad():
                try:
                    # Process image with error handling
                    image_inputs = self.processor(images=image, return_tensors="pt")
                    image_features = self.clip.get_image_features(
                        pixel_values=image_inputs.pixel_values.to(self.device)
                    )
                    logging.info("Successfully processed image through CLIP")
                    return image_features
                except Exception as e:
                    logging.error(f"Error during image processing: {str(e)}")
                    return None
        except Exception as e:
            logging.error(f"Error in process_image: {str(e)}")
            return None
    @spaces.GPU  # ZeroGPU: attach a GPU for the duration of generation
    def generate_response(self, message, image=None):
        try:
            self.ensure_models_loaded()
            if image is not None:
                image_features = self.process_image(image)
                has_image = image_features is not None
                if not has_image:
                    message = "Note: Image processing is not available - continuing with text only.\n" + message

                prompt = f"human: {'<image>' if has_image else ''}\n{message}\ngpt:"
                context = ""
                for turn in self.history[-3:]:
                    context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
                full_prompt = context + prompt

                inputs = self.tokenizer(
                    full_prompt,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                if has_image:
                    inputs["image_features"] = image_features

                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=256,
                        min_length=20,
                        temperature=0.7,
                        do_sample=True,
                        top_p=0.9,
                        top_k=40,
                        repetition_penalty=1.5,
                        no_repeat_ngram_size=3,
                        use_cache=True,
                        pad_token_id=self.tokenizer.pad_token_id,
                        eos_token_id=self.tokenizer.eos_token_id
                    )
            else:
                prompt = f"human: {message}\ngpt:"
                context = ""
                for turn in self.history[-3:]:
                    context += f"human: {turn[0]}\ngpt: {turn[1]}\n"
                full_prompt = context + prompt

                inputs = self.tokenizer(
                    full_prompt,
                    return_tensors="pt",
                    padding=True,
                    truncation=True,
                    max_length=512
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                with torch.no_grad():
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=150,
                        min_length=20,
                        temperature=0.6,
                        do_sample=True,
                        top_p=0.85,
                        top_k=30,
                        repetition_penalty=1.8,
                        no_repeat_ngram_size=4,
                        use_cache=True,
                        pad_token_id=self.tokenizer.pad_token_id,
                        eos_token_id=self.tokenizer.eos_token_id
                    )
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Clean up response
            if "gpt:" in response:
                response = response.split("gpt:")[-1].strip()
            if "human:" in response:
                response = response.split("human:")[0].strip()
            if "<image>" in response:
                response = response.replace("<image>", "").strip()

            self.history.append((message, response))
            return response

        except Exception as e:
            logging.error(f"Error generating response: {str(e)}")
            logging.error("Full traceback:", exc_info=True)
            return f"Error: {str(e)}"

    def clear_history(self):
        self.history = []
        return None
def create_demo():
    try:
        model = LLaVAPhiModel()

        with gr.Blocks(css="footer {visibility: hidden}") as demo:
            gr.Markdown(
                """
                # LLaVA-Phi Demo (ZeroGPU)
                Chat with a vision-language model that can understand both text and images.
                """
            )

            chatbot = gr.Chatbot(height=400)

            with gr.Row():
                with gr.Column(scale=14):
                    msg = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and/or upload an image",
                        container=False
                    )
                with gr.Column(scale=3, min_width=0):
                    clear = gr.Button("Clear")
                with gr.Column(scale=3, min_width=0):
                    submit = gr.Button("Submit", variant="primary")

            image = gr.Image(type="pil", label="Upload Image (Optional)")
            def respond(message, chat_history, image):
                # Nothing to do if neither text nor image was provided;
                # still return two values to match the [msg, chatbot] outputs
                if not message and image is None:
                    return message, chat_history

                response = model.generate_response(message, image)
                chat_history.append((message, response))
                return "", chat_history

            def clear_chat():
                model.clear_history()
                return None, None

            submit.click(
                respond,
                [msg, chatbot, image],
                [msg, chatbot],
            )
            clear.click(
                clear_chat,
                None,
                [chatbot, image],
            )
            msg.submit(
                respond,
                [msg, chatbot, image],
                [msg, chatbot],
            )
        return demo

    except Exception as e:
        logging.error(f"Error creating demo: {str(e)}")
        raise

if __name__ == "__main__":
    demo = create_demo()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )