| import time |
| import json |
| import uuid |
| import torch |
| from threading import Thread, Event |
| from fastapi import FastAPI, Request |
| from fastapi.responses import StreamingResponse |
| from transformers import ( |
| AutoModelForCausalLM, |
| AutoTokenizer, |
| TextIteratorStreamer, |
| LogitsProcessor, |
| LogitsProcessorList, |
| StoppingCriteria, |
| StoppingCriteriaList, |
| ) |
|
|
| |
| |
| |
# Filesystem path to the local HuggingFace-format checkpoint to serve.
MODEL_ID = "/workspace/output/glm4_7_30b/hf_temp_07i"
# Public model name reported via /v1/models (clients may send any "model"
# field; it is echoed back but MODEL_ID is always what actually runs).
VIEW_NAME = "RWKV-GLM-4.7-Flash"
# Bind address/port for the uvicorn server ("0.0.0.0" exposes it on all NICs).
HOST = "0.0.0.0"
PORT = 8000
|
|
| |
| |
| |
# Load tokenizer and model once at module import time (this blocks startup
# until the weights are on device). device_map="auto" lets accelerate place
# the bf16 weights across available GPUs; trust_remote_code is required when
# the checkpoint ships custom model/tokenizer classes.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)


app = FastAPI()
|
|
|
|
| |
| |
| |
class PresencePenaltyProcessor(LogitsProcessor):
    """OpenAI-style presence penalty.

    Subtracts a flat ``penalty`` from the logit of every token id that has
    already appeared at least once in the corresponding sequence (prompt
    tokens included), independent of how many times it occurred.
    """

    def __init__(self, penalty):
        self.penalty = penalty

    def __call__(self, input_ids, scores):
        # Walk batch rows in lockstep; editing row_scores in place also
        # mutates `scores`, which is then returned per the processor contract.
        for seen_ids, row_scores in zip(input_ids, scores):
            row_scores[torch.unique(seen_ids)] -= self.penalty
        return scores
|
|
|
|
class FrequencyPenaltyProcessor(LogitsProcessor):
    """OpenAI-style frequency penalty.

    Subtracts ``count * penalty`` from each token's logit, where ``count`` is
    how many times that token id has occurred so far in the sequence
    (prompt tokens included).
    """

    def __init__(self, penalty):
        self.penalty = penalty

    def __call__(self, input_ids, scores):
        vocab_size = scores.shape[-1]
        for row, sequence in enumerate(input_ids):
            # minlength pads the histogram out to the full vocabulary so it
            # broadcasts against the logits row.
            occurrences = torch.bincount(sequence, minlength=vocab_size)
            scores[row] -= occurrences * self.penalty
        return scores
|
|
|
|
| |
| |
| |
class CancelledStoppingCriteria(StoppingCriteria):
    """Abort generation as soon as the supplied threading.Event is set.

    Used to stop the background ``model.generate`` thread when the HTTP
    client disconnects mid-stream.
    """

    def __init__(self, stop_event: Event):
        self.stop_event = stop_event

    def __call__(self, input_ids, scores, **kwargs):
        # generate() polls this between tokens; True means "stop now".
        return self.stop_event.is_set()
|
|
|
|
| |
| |
| |
@app.get("/v1/models")
async def list_models():
    """OpenAI-compatible model listing: exposes the single local model."""
    model_entry = {
        "id": VIEW_NAME,
        "object": "model",
        "created": int(time.time()),
        "owned_by": "local",
    }
    return {"object": "list", "data": [model_entry]}
|
|
|
|
| |
| |
| |
@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completions endpoint.

    Handles both modes:
      * non-streaming: run ``model.generate`` inline and return one JSON body;
      * streaming: run ``model.generate`` in a background thread, feed tokens
        through a ``TextIteratorStreamer``, and relay them as SSE chunks,
        cancelling generation via a ``threading.Event`` if the client drops.

    NOTE(review): both ``model.generate`` (non-stream path) and the blocking
    ``for new_text in streamer`` iteration below run on the event loop thread,
    so the whole server stalls for other requests while a completion is in
    flight — consider run_in_executor/asyncio.to_thread if concurrency matters.
    """
    body = await request.json()

    # "model" is echoed back to the client; MODEL_ID is always what runs.
    model_name = body.get("model", MODEL_ID)
    # Required field — a missing "messages" raises KeyError (HTTP 500).
    messages = body["messages"]
    stream = body.get("stream", False)

    # Sampling parameters with OpenAI-ish defaults.
    temperature = body.get("temperature", 1.0)
    top_p = body.get("top_p", 1.0)
    top_k = body.get("top_k", 50)
    repetition_penalty = body.get("repetition_penalty", 1.0)
    presence_penalty = body.get("presence_penalty", 0.0)
    frequency_penalty = body.get("frequency_penalty", 0.0)
    max_tokens = body.get("max_tokens", 2048)

    # Render the chat turns into a single prompt string using the model's
    # own chat template, then tokenize onto the model's device.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # NOTE(review): only strictly-positive penalties are applied; the OpenAI
    # API also allows negative values (encouragement) — confirm intended.
    processors = LogitsProcessorList()
    if presence_penalty > 0:
        processors.append(PresencePenaltyProcessor(presence_penalty))
    if frequency_penalty > 0:
        processors.append(FrequencyPenaltyProcessor(frequency_penalty))

    generate_kwargs = dict(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        logits_processor=processors,
        # temperature == 0 is treated as greedy decoding.
        do_sample=temperature > 0,
        use_cache=True,
    )

    # ---- non-streaming path: one blocking generate, one JSON response ----
    if not stream:
        outputs = model.generate(**generate_kwargs)
        completion_tokens = outputs.shape[1] - inputs["input_ids"].shape[1]
        # Slice off the prompt so only newly generated tokens are decoded.
        # NOTE(review): skip_special_tokens=False here but True in the
        # streaming path — the two modes return different text; confirm
        # whether special tokens are wanted in non-stream output.
        generated_text = tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1] :], skip_special_tokens=False
        )
        return {
            "id": f"chatcmpl-{uuid.uuid4().hex}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": model_name,
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": generated_text},
                    # Always "stop", even when max_new_tokens was the real
                    # reason generation ended ("length" would be accurate).
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": inputs["input_ids"].shape[1],
                "completion_tokens": completion_tokens,
                "total_tokens": inputs["input_ids"].shape[1] + completion_tokens,
            },
        }

    # ---- streaming path: background generate thread + SSE relay ----
    # Setting this event makes CancelledStoppingCriteria abort generation.
    stop_event = Event()

    stopping_criteria = StoppingCriteriaList(
        [CancelledStoppingCriteria(stop_event)]
    )

    # skip_prompt=True so the streamer yields only newly generated text.
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    generation_kwargs = dict(
        **generate_kwargs,
        streamer=streamer,
        stopping_criteria=stopping_criteria,
    )

    # generate() runs in a daemonless worker thread; tokens flow back to this
    # coroutine through the streamer's internal queue.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    async def event_generator():
        """Yield OpenAI-format SSE chunks until the streamer is exhausted."""
        completion_id = f"chatcmpl-{uuid.uuid4().hex}"
        # NOTE(review): a literal "<think>" is prepended to the very first
        # chunk only — presumably to open the model's reasoning block for
        # clients that expect it; confirm this is intended for all requests.
        firsttime = "<think>"
        cancelled = False

        try:
            for new_text in streamer:
                # Client went away: tell the generate thread to stop and
                # quit relaying.
                if await request.is_disconnected():
                    stop_event.set()
                    cancelled = True
                    break

                chunk = {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": model_name,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {"content": firsttime + new_text},
                            "finish_reason": None,
                        }
                    ],
                }
                firsttime = ""
                yield f"data: {json.dumps(chunk)}\n\n"

            if not cancelled:
                yield "data: [DONE]\n\n"

        except Exception:
            # Any relay failure: cancel generation rather than leak a
            # busy GPU thread.
            stop_event.set()
            cancelled = True
        finally:
            if cancelled:
                # Drain remaining queued text so the generate thread isn't
                # blocked on a full streamer queue, then wait for it to exit.
                for _ in streamer:
                    pass
            thread.join(timeout=10)

    return StreamingResponse(
        event_generator(), media_type="text/event-stream"
    )
|
|
|
|
| |
| |
| |
if __name__ == "__main__":
    import uvicorn

    # Pass the app object directly rather than the import string
    # "test_openai_api:app": the string silently breaks if this file is
    # renamed, and an import string is only needed when reload/workers are
    # enabled (reload is off here).
    uvicorn.run(
        app,
        host=HOST,
        port=PORT,
    )