import os
import io
import torch
from PIL import Image
import numpy as np
import warnings
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from gradio.routes import mount_gradio_app
import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoProcessor,
)

# Suppress noisy third-party warnings
warnings.filterwarnings("ignore", category=UserWarning, module="gradio.analytics")
warnings.filterwarnings("ignore", category=FutureWarning)

# Force CPU only: hide CUDA devices and stub out the availability check
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
torch.cuda.is_available = lambda: False
device = "cpu"
print("Running on CPU ✅")
# ---------------- LOAD CHAT MODEL ----------------
MODEL_ID = "microsoft/Phi-3.5-mini-instruct"

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Add a padding token if the tokenizer doesn't define one
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,  # full precision; half precision gives no speedup on CPU
        device_map="cpu",
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).eval()
    print("Chat model loaded ✅")
except Exception as e:
    print(f"Chat model failed to load: {e}")
    raise
# ---------------- LOAD VISION MODEL ----------------
models = {}
processors = {}

try:
    VISION_ID = "microsoft/Phi-3.5-vision-instruct"
    # Force eager attention to avoid requiring flash-attn on CPU
    models[VISION_ID] = AutoModelForCausalLM.from_pretrained(
        VISION_ID,
        trust_remote_code=True,
        torch_dtype=torch.float32,
        device_map="cpu",
        low_cpu_mem_usage=True,
        attn_implementation="eager",
        _attn_implementation_internal="eager",  # private fallback for older transformers versions
    ).eval()
    processors[VISION_ID] = AutoProcessor.from_pretrained(
        VISION_ID,
        trust_remote_code=True,
    )
    print("Vision model loaded ✅")
except Exception as e:
    print(f"Vision model failed to load: {e}")
    # Don't raise: the app can still serve chat without vision capabilities
# ---------------- CHAT FUNCTION ----------------
def chat_simple(message, history):
    try:
        conversation = [{"role": "system", "content": "You are a helpful assistant."}]
        for user, assistant in history:
            conversation.append({"role": "user", "content": user})
            conversation.append({"role": "assistant", "content": assistant})
        conversation.append({"role": "user", "content": message})

        input_ids = tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            return_tensors="pt",
            padding=True,
            truncation=True,  # guard against over-long conversations
        )
        with torch.no_grad():  # no gradients needed for inference
            output = model.generate(
                input_ids,
                max_new_tokens=256,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=False,  # greedy decoding; a temperature would be ignored here
                use_cache=False,
            )
        # Decode only the newly generated tokens, not the prompt
        reply = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
        return reply.strip()
    except Exception as e:
        return f"Error in chat: {str(e)}"
# ---------------- VISION FUNCTION ----------------
def run_vision(image, text_input, model_id):
    # Note: `if not image:` would raise on numpy arrays, so compare to None
    if image is None:
        return "⚠️ Please upload an image first."
    if model_id not in models:
        return "⚠️ Vision model not loaded."
    try:
        model_vision = models[model_id]
        processor = processors[model_id]

        # Normalize the input into an RGB PIL image
        if isinstance(image, np.ndarray):
            img = Image.fromarray(image).convert("RGB")
        else:
            img = image.convert("RGB") if hasattr(image, "convert") else Image.open(image).convert("RGB")

        # Phi-3.5-vision expects an <|image_1|> placeholder in the prompt
        placeholder = "<|image_1|>\n"
        prompt = placeholder + (text_input or "Describe this image")
        messages = [{"role": "user", "content": prompt}]
        template = processor.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = processor(template, [img], return_tensors="pt")

        with torch.no_grad():
            output = model_vision.generate(
                **inputs,
                max_new_tokens=400,
                do_sample=False,  # greedy decoding; a temperature would be ignored here
                pad_token_id=processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id,
                use_cache=False,
            )
        # Strip the prompt tokens before decoding
        output = output[:, inputs["input_ids"].shape[1]:]
        response = processor.batch_decode(output, skip_special_tokens=True)[0]
        return response.strip()
    except Exception as e:
        return f"Error in vision processing: {str(e)}"
# ---------------- FASTAPI BACKEND ----------------
api = FastAPI(title="Phi-3.5 AI Assistant", version="1.0.0")

# The route decorators below are assumed (they were missing). API routes live
# under /api/ so they don't clash with the Gradio UI mounted at "/".
@api.get("/api")
async def root():
    return {"message": "Phi-3.5 AI Assistant API", "status": "running"}

@api.get("/api/health")
async def health():
    return {
        "status": "ok",
        "device": device,
        "vision_loaded": len(models) > 0,
        "models_available": list(models.keys()),
    }

@api.post("/api/chat")
async def api_chat(message: str = Form(...)):
    try:
        if not message.strip():
            raise HTTPException(status_code=400, detail="Message cannot be empty")
        conversation = [{"role": "user", "content": message}]
        input_ids = tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_new_tokens=256,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=False,
            )
        reply = tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
        return {"response": reply.strip()}
    except HTTPException:
        raise  # let deliberate 4xx responses through instead of wrapping them as 500s
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Chat error: {str(e)}")

@api.post("/api/vision")
async def api_vision(
    image: UploadFile = File(...),
    text_input: str = Form("Describe this image"),
    model_id: str = Form("microsoft/Phi-3.5-vision-instruct"),
):
    try:
        if not image.content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail="File must be an image")
        if model_id not in models:
            raise HTTPException(status_code=400, detail="Vision model not available")
        # Read the upload and hand it to the shared vision helper
        image_data = await image.read()
        img = Image.open(io.BytesIO(image_data)).convert("RGB")
        response = run_vision(np.array(img), text_input, model_id)
        return {"response": response}
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Vision processing error: {str(e)}")
# ---------------- GRADIO UI ----------------
def create_ui():
    with gr.Blocks(title="Phi-3.5 AI Assistant", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🚀 Phi-3.5 AI Assistant")

        with gr.Tab("💬 Chat"):
            gr.Markdown("### Chat with Phi-3.5 Mini")
            gr.ChatInterface(
                fn=chat_simple,
                title="Phi-3.5 Mini Chat",
                description="Ask me anything! I'm here to help.",
            )

        with gr.Tab("👁️ Vision"):
            gr.Markdown("### Vision Analysis with Phi-3.5 Vision")
            with gr.Row():
                with gr.Column():
                    img = gr.Image(
                        label="Upload Image",
                        type="numpy",
                        height=300,
                    )
                    txt = gr.Textbox(
                        label="Prompt",
                        value="What's in this image?",
                        placeholder="Describe what you see in the image...",
                    )
                    model_sel = gr.Dropdown(
                        choices=list(models.keys()),
                        value=list(models.keys())[0] if models else None,
                        label="Model",
                        interactive=len(models) > 1,
                    )
                    analyze_btn = gr.Button("🔍 Analyze", variant="primary")
                with gr.Column():
                    out = gr.Textbox(
                        label="Analysis Result",
                        placeholder="Results will appear here...",
                        lines=6,
                    )
            examples = gr.Examples(
                examples=[
                    ["What's in this image?", "microsoft/Phi-3.5-vision-instruct"],
                    ["Describe this image in detail", "microsoft/Phi-3.5-vision-instruct"],
                ],
                inputs=[txt, model_sel],
                label="Example Prompts",
            )
            analyze_btn.click(
                run_vision,
                inputs=[img, txt, model_sel],
                outputs=out,
            )

        with gr.Tab("ℹ️ System Info"):
            gr.Markdown("### System Information")
            gr.JSON(value={
                "device": device,
                "vision_loaded": len(models) > 0,
                "available_models": list(models.keys()),
                "chat_model": MODEL_ID,
            })
    return demo
# Create the Gradio app and mount it on top of the FastAPI backend
gradio_app = create_ui()
app = mount_gradio_app(api, gradio_app, path="/")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
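
# To run locally (assuming the usual dependencies: torch, transformers,
# gradio, fastapi, uvicorn, pillow, numpy, and a file name of app.py, the
# Hugging Face Spaces convention):
#   python app.py
# then open http://localhost:7860 for the UI, or GET /api/health for a probe.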