Spaces:
Sleeping
Sleeping
"""
Lightweight Multi-Model AI Backend for Hugging Face Gradio Space
Optimized for FREE CPU tier - No GPU required
"""
import gc
import os
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from PIL import Image, ImageDraw
import numpy as np
import base64
from io import BytesIO

# ===== DEVICE CONFIGURATION =====
# CPU-only deployment: cap torch's thread count so we behave on a shared
# free-tier host, and silence the tokenizers fork-parallelism warning.
device = "cpu"
torch.set_num_threads(4)
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# ===== MODEL MANAGER =====
class ModelManager:
    """Lazily loads and caches the chat model and the summarization pipeline.

    Nothing is loaded at construction time; each model is fetched on first
    use so startup stays fast and idle memory stays low.
    """

    def __init__(self):
        # Caches; None means "not loaded yet".
        self.chat_model = None
        self.chat_tokenizer = None
        self.summarizer_pipeline = None

    def load_chat_model(self):
        """Return (model, tokenizer) for TinyLlama, loading it on first call."""
        if self.chat_model is not None:
            return self.chat_model, self.chat_tokenizer
        print("Loading TinyLlama...")
        model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        self.chat_tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.chat_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            device_map=device,
            low_cpu_mem_usage=True,
        )
        self.chat_model.eval()
        gc.collect()
        return self.chat_model, self.chat_tokenizer

    def load_summarizer(self):
        """Return the FLAN-T5 summarization pipeline, loading it on first call."""
        if self.summarizer_pipeline is not None:
            return self.summarizer_pipeline
        print("Loading FLAN-T5...")
        self.summarizer_pipeline = pipeline(
            "summarization",
            model="google/flan-t5-small",
            framework="pt",
            device=-1,  # -1 selects CPU in transformers pipelines
        )
        gc.collect()
        return self.summarizer_pipeline

model_manager = ModelManager()
# ===== GENERATION FUNCTIONS =====
def chat_fn(prompt, max_tokens, temperature):
    """Generate a chat response with TinyLlama.

    Args:
        prompt: User message text (truncated to 512 input tokens).
        max_tokens: Maximum new tokens to generate (clamped to 200).
        temperature: Sampling temperature (floored at 0.1 so sampling is valid).

    Returns:
        The generated response text, or an "Error: ..." string on failure.
    """
    try:
        max_tokens = min(int(max_tokens), 200)
        model, tokenizer = model_manager.load_chat_model()
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        prompt_len = inputs["input_ids"].shape[1]
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                # Guard against temperature=0, which breaks do_sample=True
                # (matches the floor code_fn already applies).
                temperature=max(float(temperature), 0.1),
                top_p=0.9,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Fix: decode only the newly generated tokens. The original decoded
        # outputs[0] in full, which echoed the prompt back in the response.
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        del inputs, outputs
        gc.collect()
        return response
    except Exception as e:
        # Best-effort UI endpoint: surface the failure as text, never crash.
        return f"Error: {str(e)}"
def code_fn(prompt, max_tokens, temperature):
    """Generate Python code from a natural-language description via TinyLlama.

    Args:
        prompt: Description of the code to generate.
        max_tokens: Maximum new tokens to generate (clamped to 300).
        temperature: Sampling temperature (floored at 0.1).

    Returns:
        The generated code text, or an "Error: ..." string on failure.
    """
    try:
        max_tokens = min(int(max_tokens), 300)
        model, tokenizer = model_manager.load_chat_model()
        code_prompt = f"Generate Python code: {prompt}"
        inputs = tokenizer(code_prompt, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        prompt_len = inputs["input_ids"].shape[1]
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                temperature=max(temperature, 0.1),
                top_p=0.95,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Fix: decode only the newly generated tokens. The original decoded
        # outputs[0] in full, echoing the instruction prompt in the result.
        response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        del inputs, outputs
        gc.collect()
        return response
    except Exception as e:
        # Best-effort UI endpoint: surface the failure as text, never crash.
        return f"Error: {str(e)}"
def summarize_fn(text, max_length):
    """Summarize *text* with FLAN-T5-small.

    Args:
        text: Input text; must be at least 50 characters after stripping.
        max_length: Maximum summary length in tokens (clamped to 150).

    Returns:
        The summary string, a too-short notice, or "Error: ..." on failure.
    """
    try:
        if len(text.strip()) < 50:
            return "Text too short (min 50 chars)"
        # Bound input size to keep CPU latency reasonable (slicing a shorter
        # string is a no-op, so no length check is needed).
        text = text[:1000]
        max_len = min(int(max_length), 150)
        # Guard: the HF pipeline raises if max_length < min_length.
        min_len = min(20, max_len)
        summarizer = model_manager.load_summarizer()
        summary = summarizer(text, max_length=max_len, min_length=min_len, do_sample=False)
        gc.collect()
        return summary[0]['summary_text']
    except Exception as e:
        # Best-effort UI endpoint: surface the failure as text, never crash.
        return f"Error: {str(e)}"
def image_fn(prompt, width, height):
    """Render a deterministic procedural gradient image seeded by *prompt*.

    Args:
        prompt: Text whose stable hash seeds the gradient pattern.
        width: Image width in pixels (clamped to 256).
        height: Image height in pixels (clamped to 256).

    Returns:
        A "data:image/png;base64,..." URI string, or "Error: ..." on failure.
    """
    try:
        import hashlib  # local import: only this function needs it

        width, height = min(int(width), 256), min(int(height), 256)
        # Fix: the original used built-in hash(), which is salted per process
        # (PYTHONHASHSEED), so the same prompt produced different images on
        # every restart. A SHA-256-derived seed is stable across runs.
        seed = int.from_bytes(hashlib.sha256(prompt.encode("utf-8")).digest()[:4], "big")
        # NOTE: the original also seeded np.random/torch RNGs, but no random
        # numbers are ever drawn here — the pattern is pure sin/cos — so that
        # dead seeding is dropped.
        # Vectorized with numpy broadcasting; the original per-pixel Python
        # loop did up to 65k interpreted iterations at 256x256.
        xs = np.arange(width)[np.newaxis, :]   # shape (1, W)
        ys = np.arange(height)[:, np.newaxis]  # shape (H, 1)
        r = np.sin(xs / 50 + seed) * 127 + 128
        g = np.cos(ys / 50 + seed * 0.5) * 127 + 128
        b = np.sin((xs + ys) / 100 + seed * 0.7) * 127 + 128
        # Values lie in [1, 255], so uint8 truncation matches the original int().
        rgb = np.stack(np.broadcast_arrays(r, g, b), axis=-1).astype(np.uint8)
        img = Image.fromarray(rgb, mode="RGB")
        buffered = BytesIO()
        img.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode()
        return f"data:image/png;base64,{img_str}"
    except Exception as e:
        # Best-effort UI endpoint: surface the failure as text, never crash.
        return f"Error: {str(e)}"
# ===== GRADIO INTERFACE =====
# Create individual interfaces
# One gr.Interface per capability; each wraps the matching *_fn above.
# Slider ranges mirror the clamps enforced inside the functions.
chat_demo = gr.Interface(
    fn=chat_fn,
    inputs=[
        gr.Textbox(lines=3, label="Message"),
        gr.Slider(50, 200, 150, step=10, label="Max Tokens"),
        gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature")
    ],
    outputs=gr.Textbox(lines=10, label="Response"),
    title="π¬ Chat"
)
code_demo = gr.Interface(
    fn=code_fn,
    inputs=[
        gr.Textbox(lines=3, label="Description"),
        gr.Slider(100, 300, 256, step=20, label="Max Tokens"),
        gr.Slider(0.1, 1.0, 0.3, step=0.1, label="Temperature")
    ],
    outputs=gr.Textbox(lines=10, label="Code"),
    title="π» Code"
)
summarize_demo = gr.Interface(
    fn=summarize_fn,
    inputs=[
        gr.Textbox(lines=8, label="Text"),
        gr.Slider(20, 150, 100, step=10, label="Summary Length")
    ],
    outputs=gr.Textbox(lines=8, label="Summary"),
    title="π Summarize"
)
image_demo = gr.Interface(
    fn=image_fn,
    inputs=[
        gr.Textbox(label="Description"),
        gr.Slider(128, 256, 256, step=32, label="Width"),
        gr.Slider(128, 256, 256, step=32, label="Height")
    ],
    # Output is a base64 data URI rendered as text, not a gr.Image.
    outputs=gr.Textbox(label="Image (Base64)"),
    title="π¨ Image"
)
# Create tabbed interface
# Bundle the four tools into a single tabbed app served at one endpoint.
demo = gr.TabbedInterface(
    [chat_demo, code_demo, summarize_demo, image_demo],
    tab_names=["π¬ Chat", "π» Code", "π Summarize", "π¨ Image"],
    title="π€ Lightweight AI Backend"
)
# ===== INITIALIZE AND RUN =====
if __name__ == "__main__":
    # Startup banner with the runtime configuration chosen above.
    banner = "=" * 60
    print(banner)
    print("π Lightweight AI Backend Starting...")
    print(banner)
    print(f"Device: {device}")
    print(f"CPU Threads: {torch.get_num_threads()}")
    print(banner)
    # Bounded queue keeps concurrent requests from overwhelming the free CPU tier.
    demo.queue(max_size=10, default_concurrency_limit=2)
    # Bind on all interfaces at the standard HF Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)