Spaces:
Configuration error
Configuration error
| import os | |
| import time | |
| import torch | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
| from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM | |
| from langgraph.graph import StateGraph | |
| # --- Model and Workflow Setup --- | |
| model_id = "ibm-granite/granite-4.0-tiny-preview" | |
| processor = AutoProcessor.from_pretrained(model_id) | |
| tokenizer = AutoTokenizer.from_pretrained(model_id) | |
| model = AutoModelForCausalLM.from_pretrained(model_id) | |
| def generate_with_granite(prompt: str, max_tokens: int = 200, use_gpu: bool = False) -> str: | |
| device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu") | |
| model.to(device) | |
| messages = [{"role": "user", "content": prompt}] | |
| inputs = processor.apply_chat_template( | |
| messages, | |
| add_generation_prompt=True, | |
| tokenize=True, | |
| return_tensors="pt" | |
| ).to(device) | |
| outputs = model.generate( | |
| input_ids=inputs, | |
| max_new_tokens=max_tokens, | |
| do_sample=True, | |
| temperature=0.7, | |
| top_p=0.9, | |
| pad_token_id=tokenizer.eos_token_id | |
| ) | |
| generated = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True) | |
| return generated.strip() | |
| def select_genre_node(state: dict) -> dict: | |
| prompt = f""" | |
| You are a creative assistant. The user wants to write a short animated story. | |
| Based on the following input, suggest a suitable genre and tone for the story. | |
| User Input: {state['user_input']} | |
| Respond in this format: | |
| Genre: <genre> | |
| Tone: <tone> | |
| """.strip() | |
| response = generate_with_granite(prompt) | |
| genre, tone = None, None | |
| for line in response.splitlines(): | |
| if "Genre:" in line: | |
| genre = line.split("Genre:")[1].strip() | |
| elif "Tone:" in line: | |
| tone = line.split("Tone:")[1].strip() | |
| state["genre"] = genre | |
| state["tone"] = tone | |
| return state | |
| def generate_outline_node(state: dict) -> dict: | |
| prompt = f""" | |
| You are a creative writing assistant helping to write a short animated screenplay. | |
| The user wants to write a story with the following details: | |
| Genre: {state.get('genre')} | |
| Tone: {state.get('tone')} | |
| Idea: {state.get('user_input')} | |
| Write a brief plot outline (3–5 sentences) for the story. | |
| """.strip() | |
| response = generate_with_granite(prompt, max_tokens=250) | |
| state["outline"] = response | |
| return state | |
| def generate_scene_node(state: dict) -> dict: | |
| prompt = f""" | |
| You are a screenwriter. | |
| Based on the following plot outline, write a key scene from the story. | |
| Focus on a turning point or climax moment. Make the scene vivid, descriptive, and suitable for an animated short film. | |
| Genre: {state.get('genre')} | |
| Tone: {state.get('tone')} | |
| Outline: {state.get('outline')} | |
| Write the scene in prose format (not screenplay format). | |
| """.strip() | |
| response = generate_with_granite(prompt, max_tokens=300) | |
| state["scene"] = response | |
| return state | |
| def write_dialogue_node(state: dict) -> dict: | |
| prompt = f""" | |
| You are a dialogue writer for an animated screenplay. | |
| Below is a scene from the story: | |
| {state.get('scene')} | |
| Write the dialogue between the characters in screenplay format. | |
| Keep it short, expressive, and suitable for a short animated film. | |
| Use character names (you may invent them if needed), and format as: | |
| CHARACTER: | |
| Dialogue line | |
| CHARACTER: | |
| Dialogue line | |
| """.strip() | |
| response = generate_with_granite(prompt, max_tokens=100) | |
| state["dialogue"] = response | |
| return state | |
| def with_progress(fn, label, index, total): | |
| def wrapper(state): | |
| print(f"\n[{index}/{total}] Starting: {label}") | |
| start = time.time() | |
| result = fn(state) | |
| duration = time.time() - start | |
| print(f"[{index}/{total}] Completed: {label} in {duration:.2f} seconds") | |
| return result | |
| return wrapper | |
| def build_workflow(): | |
| graph = StateGraph(dict) | |
| graph.add_node("select_genre", with_progress(select_genre_node, "Select Genre", 1, 4)) | |
| graph.add_node("generate_outline", with_progress(generate_outline_node, "Generate Outline", 2, 4)) | |
| graph.add_node("generate_scene", with_progress(generate_scene_node, "Generate Scene", 3, 4)) | |
| graph.add_node("write_dialogue", with_progress(write_dialogue_node, "Write Dialogue", 4, 4)) | |
| graph.set_entry_point("select_genre") | |
| graph.add_edge("select_genre", "generate_outline") | |
| graph.add_edge("generate_outline", "generate_scene") | |
| graph.add_edge("generate_scene", "write_dialogue") | |
| graph.set_finish_point("write_dialogue") | |
| return graph.compile() | |
| workflow = build_workflow() | |
| # --- Flask App --- | |
| app = Flask(__name__) | |
| CORS(app) | |
| def generate_story(): | |
| data = request.get_json() | |
| user_input = data.get("user_input") | |
| if not user_input: | |
| return jsonify({"error": "Missing 'user_input' in request."}), 400 | |
| initial_state = {"user_input": user_input} | |
| final_state = workflow.invoke(initial_state) | |
| return jsonify({ | |
| "genre": final_state.get("genre"), | |
| "tone": final_state.get("tone"), | |
| "outline": final_state.get("outline"), | |
| "scene": final_state.get("scene"), | |
| "dialogue": final_state.get("dialogue") | |
| }) | |
| if __name__ == "__main__": | |
| app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 8000)), debug=True) |