import eventlet

# Monkey patching must happen before any other module that touches sockets is imported.
eventlet.monkey_patch(socket=True, select=True, thread=True)

import eventlet.wsgi
from flask import Flask, render_template, request
from flask_socketio import SocketIO
# NOTE: MultiBeamTextStreamer is not part of every transformers release;
# this demo assumes a build that includes it.
from transformers import MultiBeamTextStreamer, AutoTokenizer, AutoModelForCausalLM
import torch

app = Flask(__name__)
socketio = SocketIO(
    app,
    async_mode='eventlet',
    message_queue=None,
    ping_timeout=60,
    ping_interval=25,
    cors_allowed_origins="*",
    logger=True,
    engineio_logger=True,
    async_handlers=True
)
# Initialize models and tokenizers
MODELS = {
    "qwen": {
        "name": "Qwen/Qwen2.5-0.5B-Instruct",
        "tokenizer": None,
        "model": None,
        "uses_chat_template": True  # Qwen uses a chat template
    },
    "gpt2": {
        "name": "gpt2",
        "tokenizer": None,
        "model": None,
        "uses_chat_template": False  # GPT-2 doesn't use a chat template
    }
}
# Load models and tokenizers
for model_key, model_info in MODELS.items():
    model_info["tokenizer"] = AutoTokenizer.from_pretrained(model_info["name"])
    model_info["model"] = AutoModelForCausalLM.from_pretrained(
        model_info["name"],
        torch_dtype="auto",
        device_map="auto"
    )

    # Add a pad token for GPT-2 if it doesn't have one
    if model_key == "gpt2" and model_info["tokenizer"].pad_token is None:
        model_info["tokenizer"].pad_token = model_info["tokenizer"].eos_token
        model_info["model"].config.pad_token_id = model_info["model"].config.eos_token_id
class WebSocketBeamStreamer(MultiBeamTextStreamer):
    """Custom streamer that sends updates through websockets with adjustable speed."""

    def __init__(self, tokenizer, num_beams, sleep_time=0, skip_prompt=True):
        super().__init__(
            tokenizer,
            num_beams=num_beams,
            skip_prompt=skip_prompt,
            on_beam_update=self.on_beam_update,
            on_beam_finished=self.on_beam_finished
        )
        self.beam_texts = {i: "" for i in range(num_beams)}
        self.sleep_time = sleep_time  # delay between beam updates, in milliseconds

    def on_beam_update(self, beam_idx: int, new_text: str):
        self.beam_texts[beam_idx] = new_text
        if self.sleep_time > 0:
            eventlet.sleep(self.sleep_time / 1000)
        socketio.emit('beam_update', {
            'beam_idx': beam_idx,
            'text': new_text
        }, namespace='/', callback=lambda: eventlet.sleep(0))
        # Yield control to the eventlet loop so the emit is flushed immediately
        socketio.sleep(0)

    def on_beam_finished(self, final_text: str):
        socketio.emit('beam_finished', {
            'text': final_text
        })
@app.route('/')  # assumed route: the handler serves the UI page
def index():
    return render_template('index.html')
@socketio.on('generate')  # assumed event name for the client's generation request
def handle_generation(data):
    socketio.emit('generation_started')

    prompt = data['prompt']
    model_name = data.get('model', 'qwen')  # Default to qwen if not specified
    num_beams = data.get('num_beams', 5)
    max_new_tokens = data.get('max_tokens', 512)
    sleep_time = data.get('sleep_time', 0)

    # Get the selected model info
    model_info = MODELS[model_name]
    model = model_info["model"]
    tokenizer = model_info["tokenizer"]

    # Prepare input text based on model type
    if model_info["uses_chat_template"]:
        # For Qwen, use the chat template
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
    else:
        # For GPT-2, use the prompt directly
        text = prompt

    # Prepare inputs
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Initialize the streamer
    streamer = WebSocketBeamStreamer(
        tokenizer=tokenizer,
        num_beams=num_beams,
        sleep_time=sleep_time,
        skip_prompt=True
    )

    try:
        # Generate with beam search
        with torch.no_grad():
            model.generate(
                **model_inputs,
                num_beams=num_beams,
                num_return_sequences=num_beams,
                max_new_tokens=max_new_tokens,
                output_scores=True,
                return_dict_in_generate=True,
                early_stopping=True,
                streamer=streamer,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
    except Exception as e:
        socketio.emit('generation_error', {'error': str(e)})
    finally:
        socketio.emit('generation_completed')
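
# The listing ends without an entry point. A minimal sketch, assuming the
# standard Flask-SocketIO pattern; the host and port (7860 is the Hugging Face
# Spaces default) are assumptions, not values taken from the original app.
if __name__ == '__main__':
    socketio.run(app, host='0.0.0.0', port=7860)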