from threading import Thread

import torch
from flask import Flask, Response, request, stream_with_context
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)

app = Flask(__name__)

model_id = "google/gemma-3-1b-it"  # Official instruction-tuned model
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Load in 4-bit to reduce memory use. Note: bitsandbytes 4-bit quantization
# requires a CUDA GPU; on a CPU-only machine, drop quantization_config and
# load the model in full or bfloat16 precision instead.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    low_cpu_mem_usage=True,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)


@app.route("/generate", methods=["POST"])
def generate():
    data = request.get_json()
    prompt = data.get("prompt", "")

    # Format for Gemma 3; its chat template folds the system turn
    # into the first user turn.
    messages = [
        {"role": "system", "content": "You are Jarvis. Be concise."},
        {"role": "user", "content": prompt},
    ]
    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)  # move inputs to the same device as the model

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # Run generation in a separate thread so we can yield tokens immediately
    generation_kwargs = dict(input_ids=inputs, streamer=streamer, max_new_tokens=128)
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    @stream_with_context
    def stream_words():
        for new_text in streamer:
            yield new_text

    return Response(stream_words(), mimetype="text/plain")


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
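
# Usage sketch (illustrative, not part of the server): a minimal streaming
# client run as a separate process. Assumes the `requests` package is
# installed and the server above is listening on localhost:7860; the prompt
# text is just an example.
#
#   import requests
#
#   with requests.post(
#       "http://localhost:7860/generate",
#       json={"prompt": "Explain token streaming in one sentence."},
#       stream=True,
#   ) as resp:
#       # chunk_size=None yields data as it arrives from the server
#       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk, end="", flush=True)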